diff --git a/records/071825/record.txt b/records/071825/record.txt new file mode 100644 index 000000000..ce9fc4d72 --- /dev/null +++ b/records/071825/record.txt @@ -0,0 +1,2848 @@ +import os +import sys +with open(sys.argv[0]) as f: + code = f.read() # read the code of this file ASAP, for logging +import uuid +import time +import copy +import glob +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +import torch +torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems +from torch import Tensor, nn +import torch.nn.functional as F +import torch.distributed as dist +# use of FlexAttention contributed by @KoszarskyB +from torch.nn.attention.flex_attention import BlockMask, flex_attention +#torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as tl + +# ----------------------------------------------------------------------------- +# Custom operators: FP8 matmul by @YouJiacheng + +@torch.library.custom_op("nanogpt::mm", mutates_args=()) +def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]: + @torch.compile + def impl(x: Tensor, w: Tensor): + assert x.is_contiguous() and w.is_contiguous() + x_f8 = x.div(x_s).to(torch.float8_e4m3fn) + w_f8 = w.div(w_s).to(torch.float8_e4m3fn) + out = torch._scaled_mm( + x_f8, + w_f8.T, + out_dtype=torch.bfloat16, + scale_a=x.new_tensor(x_s, dtype=torch.float32), + scale_b=x.new_tensor(w_s, dtype=torch.float32), + use_fast_accum=True, + ) + return out, x_f8, w_f8 + + return impl(x, w) + +@mm_op.register_fake +def _(x: Tensor, w: Tensor, *_): + assert x.ndim == w.ndim == 2 + assert x.shape[1] == w.shape[1] + assert x.device == w.device + assert x.is_contiguous() and w.is_contiguous() + return x @ w.T, 
x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) + +@torch.library.custom_op("nanogpt::mm_backward", mutates_args=()) +def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]: + @torch.compile + def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): + assert grad.is_contiguous() + x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) + w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) + grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) + grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) + grad_x = torch._scaled_mm( + grad_f8, + w_f8.T.contiguous().T, + out_dtype=torch.bfloat16, + scale_a=grad_inv_s, + scale_b=w_inv_s, + use_fast_accum=False, + ) + # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) + grad_w = torch._scaled_mm( + x_f8.T.contiguous(), + grad_f8.T.contiguous().T, + out_dtype=torch.float32, + scale_a=x_inv_s, + scale_b=grad_inv_s, + use_fast_accum=False, + ).T + return grad_x, grad_w + + return impl(g, x_f8, w_f8) + +@mm_backward_op.register_fake +def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): + return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) + +def backward(ctx, grad_out: Tensor, *_): + x_f8, w_f8 = ctx.saved_tensors + x_s, w_s, grad_s = ctx.scales + grad_x, grad_w = torch.ops.nanogpt.mm_backward( + grad_out, x_f8, w_f8, x_s, w_s, grad_s + ) + return grad_x, grad_w, None, None, None + +def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): + *_, x_s, w_s, grad_s = inputs + _, x_f8, w_f8 = output + ctx.save_for_backward(x_f8, w_f8) + ctx.scales = x_s, w_s, grad_s + ctx.set_materialize_grads(False) + +mm_op.register_autograd(backward, setup_context=setup_context) + +# ----------------------------------------------------------------------------- +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + 
"BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + 
tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + 
a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += 
BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): + """ + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A + """ + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + 
alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(-2) > G.size(-1): + X = X.mT + + # Ensure spectral norm is at most 1 + X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + + # Perform the NS iterations + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies + + if G.size(-2) > G.size(-1): + X = X.mT + return X + +# ----------------------------------------------------------------------------- +# Muon optimizer + +class Muon(torch.optim.Optimizer): + """ + Muon - MomentUm Orthogonalized by Newton-schulz + + https://kellerjordan.github.io/posts/muon/ + + Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- + processing step, in which each 2D parameter's update is replaced with the nearest orthogonal + matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has + the advantage that it can be stably run in bfloat16 on the GPU. + + Warning: This optimizer should not be used for the embedding layer, the final fully connected layer, + or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW). 
+ """ + def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + + @torch.no_grad() + def step(self): + # Efficient systems-wise implementation of step developed by @YouJiacheng, + # @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad, + # @ryanyang0, and @vagrawal. + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + grad_pad = [param.grad for param in params] + [torch.zeros_like(params[-1])] * world_size + for base_i in range(0, len(params), world_size): + if base_i + rank < len(params): + grad = params[base_i + rank].grad + # This gives strange dynamo warnings + reduce_scatter_futures.append(dist.reduce_scatter(grad, grad_pad[base_i:base_i + world_size], op=dist.ReduceOp.AVG, async_op=True).get_future()) + + idx = 0 + for group in self.param_groups: + params: list[Tensor] = group["params"] + params_pad = params + [torch.empty_like(params[-1])] * world_size + momentum = group["momentum"] + for base_i in range(0, len(params), world_size): + reduce_scatter_futures[idx].wait() + if base_i + rank < len(params): + p = params[base_i + rank] + grad = p.grad + eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0) + eff_weight_decay = group["lr"] * group["weight_decay"] * getattr(p, "wd_mul", 1.0) + state = self.state[p] + if len(state) == 0: + state["momentum_buffer"] = torch.zeros_like(grad) + momentum_buffer = 
state["momentum_buffer"] + p.mul_(1 - eff_weight_decay) + momentum_buffer.lerp_(grad, 1 - momentum) + grad = grad.lerp_(momentum_buffer, momentum) + v = newton_schulz_triton(grad) + p.add_(other=v, alpha=-eff_lr) + idx += 1 + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +class DistAdam(torch.optim.Optimizer): + def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + params = list(params) + sizes = {p.shape for p in params} + # create one buffer per unique parameter-size + param_groups = [] + for size in sizes: + group_params = [p for p in params if p.shape == size] + param_groups.append(dict(params=group_params)) + super().__init__(param_groups, defaults) + # DistributedAdam implementation by @vagrawal + + @torch.compile + @torch.no_grad() + def step(self): + rank = dist.get_rank() + world_size = dist.get_world_size() + reduce_scatter_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] + grad_slices = [] + for group in self.param_groups: + params: list[Tensor] = group["params"] + grad = torch.empty_like(params[-1]) + for base_i in range(len(params)): + grad = params[base_i].grad + rank_size = grad.shape[0] // world_size + grad_slice = torch.empty_like(grad[:rank_size]) + reduce_scatter_futures.append(dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()) + grad_slices.append(grad_slice) + + idx = 0 + for group in self.param_groups: + beta1, beta2 = group['betas'] + eps = group['eps'] + wd = group['weight_decay'] + params = group['params'] + for base in range(len(params)): + reduce_scatter_futures[idx].wait() + p = params[base] + rank_size = p.shape[0] // world_size + p_slice = p[rank * 
rank_size:(rank + 1) * rank_size] + lr = group['lr'] * getattr(p, "lr_mul", 1.0) + state = self.state[p] + g_slice = grad_slices[idx] + # State init + if not state: + state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device) + state['exp_avg'] = torch.zeros_like(p_slice) + state['exp_avg_sq'] = torch.zeros_like(p_slice) + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + state['step'] += 1 + t = state['step'] + # weight decay + if wd != 0: + eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0) + p_slice.mul_(1 - eff_weight_decay) + # update running averages + exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2) + # bias corrections + bias1 = 1 - beta1 ** t + bias2 = 1 - beta2 ** t + # compute step + denom = exp_avg_sq.sqrt().add_(eps) + step_size = lr * (torch.sqrt(bias2) / bias1) + update = exp_avg.div(denom).mul_(step_size) + p_slice.add_(other=update, alpha=-1.0) + idx += 1 + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() + +# ----------------------------------------------------------------------------- +# PyTorch nn.Module definitions for the model + +def norm(x: Tensor): + return F.rms_norm(x, (x.size(-1),)) + +class CastedLinear(nn.Linear): + def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0): + super().__init__(in_features, out_features, bias=False) + self.use_fp8 = use_fp8 + self.x_s = x_s + self.w_s = w_s + self.grad_s = grad_s + + def reset_parameters(self) -> None: + std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3) + bound = (3 ** 0.5) * std + with torch.no_grad(): + self.weight.uniform_(-bound, bound) + + def forward(self, x: Tensor): + if self.use_fp8 and self.training: + _x = x.flatten(0, -2) + out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, 
w_s=self.w_s, grad_s=self.grad_s)[0] + return out.reshape(*x.shape[:-1], -1) + else: + return F.linear(x, self.weight.type_as(x)) + +class Rotary(nn.Module): + def __init__(self, dim: int, max_seq_len: int): + super().__init__() + # half-truncate RoPE by @YouJiacheng (w/ base freq tuning) + angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32) + angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)]) + t = torch.arange(max_seq_len, dtype=torch.float32) + theta = torch.einsum("i,j -> ij", t, angular_freq) + self.cos = nn.Buffer(theta.cos(), persistent=False) + self.sin = nn.Buffer(theta.sin(), persistent=False) + + def forward(self, x_BTHD: Tensor): + assert self.cos.size(0) >= x_BTHD.size(-3) + cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :] + x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1) + y1 = x1 * cos + x2 * sin + y2 = x1 * (-sin) + x2 * cos + return torch.cat((y1, y2), 3).type_as(x_BTHD) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng + # https://x.com/hi_tysam/status/1879699187107033311 + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero + self.rotary = Rotary(head_dim, max_seq_len) + # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun + # inspired by learnable scalars used by @brendanh0gan 
https://x.com/hi_tysam/status/1879693583898591283 + self.attn_scale = 0.12 + + def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): + B, T = x.size(0), x.size(1) # batch size, sequence length + assert B == 1, "Must use batch size = 1 for FlexAttention" + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k = norm(q), norm(k) # QK norm @Grad62304977 + q, k = self.rotary(q), self.rotary(k) + if ve is not None: + v = lambdas[0] * v + lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977 + else: # skip mid-layers token value embeddings by @YouJiacheng + v = lambdas[0] * v + y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2) + y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side + y = F.linear(y, self.qkvo_w[3].type_as(y)) + return y + +class MLP(nn.Module): + def __init__(self, dim: int): + super().__init__() + hdim = 4 * dim + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 + + def forward(self, x: Tensor): + x = F.linear(x, self.c_fc.T.type_as(x)) + x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 + x = F.linear(x, self.c_proj.type_as(x)) + return x + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int): + super().__init__() + # skip attention of blocks.7 
(the 8th layer) by @YouJiacheng + self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None + self.mlp = MLP(dim) + + def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, lambdas: Tensor, sa_lambdas: Tensor, block_mask: BlockMask): + x = lambdas[0] * x + lambdas[1] * x0 + if self.attn is not None: + x = x + self.attn(norm(x), ve, sa_lambdas, block_mask) + x = x + self.mlp(norm(x)) + return x + +# ----------------------------------------------------------------------------- +# The main model + +def next_multiple_of_n(v: float | int, *, n: int): + return next(x for x in range(n, int(v) + 1 + n, n) if x >= v) + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int): + super().__init__() + vocab_size = next_multiple_of_n(vocab_size, n=128) + self.embed = nn.Embedding(vocab_size, model_dim) + # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897 + # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78 + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)]) + self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) + # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. + # suggested to me by @Grad62304977. this originates from Karpathy's experiments. 
+ use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448) + self.lm_head.weight.detach().zero_() # @Grad62304977 + # Add learnable skip connection weights for decoder layers + assert num_layers % 2 == 0 + pad = (-num_layers * 5) % dist.get_world_size() + self.scalars = nn.Parameter(torch.cat([ + torch.ones(num_layers), # skip_weights + *[torch.tensor([1.0, 0.0]) for _ in range(num_layers)], # block lambdas + *[torch.tensor([0.5, 0.5]) for _ in range(num_layers)], # SA lambdas + torch.ones(pad), + ])) + # set learning rates + for param in self.embed.parameters(): + param.lr_mul = 75. + for param in self.value_embeds.parameters(): + param.lr_mul = 75. + self.lm_head.weight.lr_mul = 27.5 + self.scalars.lr_mul = 5.0 + + def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor): + BLOCK_SIZE = 128 + docs = (input_seq == 50256).cumsum(0) + + def document_causal(b, h, q_idx, kv_idx): + causal_mask = q_idx >= kv_idx + document_mask = docs[q_idx] == docs[kv_idx] + return causal_mask & document_mask + + def dense_to_ordered(dense_blockmask: Tensor): + num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32) + indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32) + return num_blocks[None, None].contiguous(), indices[None, None].contiguous() + + # manual block mask creation by @YouJiacheng + assert len(input_seq) % BLOCK_SIZE == 0 + NUM_BLOCKS = len(input_seq) // BLOCK_SIZE + block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda") + causal_blockmask_any = block_idx[:, None] >= block_idx + causal_blockmask_all = block_idx[:, None] > block_idx + docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous() + docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous() + document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low) + document_blockmask_all = 
(docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low) + blockmask_any = causal_blockmask_any & document_blockmask_any + blockmask_all = causal_blockmask_all & document_blockmask_all + partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all) + full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all) + def build_bm(window_size_blocks: Tensor) -> BlockMask: + return BlockMask.from_kv_blocks( + torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)), + partial_kv_indices, + torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1), + full_kv_indices, + BLOCK_SIZE=BLOCK_SIZE, + mask_mod=document_causal, + ) + # Long-short SWA block masks by @leloykun & @YouJiacheng, adapated from suggestion by @Grad62304977, following Gemma 2 paper + return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2) + + def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor): + assert input_seq.ndim == 1 + + ve = [value_embed(input_seq) for value_embed in self.value_embeds] + # 012 ... 
012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure + ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]] + assert len(ve) == len(self.blocks) + + long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks) + block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm] + assert len(block_masks) == len(self.blocks) + + x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977 + + # U-net design by @brendanh0gan + skip_connections = [] + skip_weights = self.scalars[:(len(self.blocks) // 2)] + lambdas = self.scalars[1 * len(self.blocks): 3 * len(self.blocks)].view(-1, 2) + sa_lambdas = self.scalars[3 * len(self.blocks): 5 * len(self.blocks)].view(-1, 2) + + n = len(self.blocks) // 2 + + for i in range(len(self.blocks)): + if i >= n: + x = x + skip_weights[i - n] * skip_connections.pop() + x = self.blocks[i](x, ve[i], x0, lambdas[i], sa_lambdas[i], block_masks[i]) + if i < n: + skip_connections.append(x) + + x = norm(x) + logits = self.lm_head(x).float() + # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1) + logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5)) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction="sum" if self.training else "mean") + return loss + +# ----------------------------------------------------------------------------- +# Distributed data loader + +def _load_data_shard(file: Path): + header = torch.from_file(str(file), False, 256, dtype=torch.int32) # header is 256 int32 + assert header[0] == 20240520, "magic number mismatch in the data .bin file" + assert header[1] == 1, "unsupported version" + num_tokens = int(header[2]) # number of tokens (claimed) + with file.open("rb", buffering=0) as f: + tokens = 
torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True) # avoid pin_memory copy by @YouJiacheng + f.seek(256 * 4) + nbytes = f.readinto(tokens.numpy()) # avoid bytes->array copy by @YouJiacheng + assert nbytes == 2 * num_tokens, "number of tokens read does not match header" + return tokens + +# find world_size starting indicies, such that each begins with token 50256 and local_batches don't overlap +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 + boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos + start = boundary_positions[0].item() + starts = [] + for i in range(1, len(boundary_positions)): + end = boundary_positions[i].item() + if end - start >= seq_len: + starts.append(start) # append start once end pos is confirmed + if len(starts) == dist.get_world_size(): + return starts, end - pos + start = end + assert False # increase token_window if necessary + +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): + rank = dist.get_rank() + world_size = dist.get_world_size() + batch_size = seq_len * world_size + files = [Path(file) for file in sorted(glob.glob(filename_pattern))] + file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training + tokens, pos = _load_data_shard(next(file_iter)), 0 + while True: + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = 
buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets + +# ----------------------------------------------------------------------------- +# int main + +@dataclass +class Hyperparameters: + # data + train_files = "data/fineweb10B/fineweb_train_*.bin" # input .bin to train on + val_files = "data/fineweb10B/fineweb_val_*.bin" # input .bin to eval validation loss on + val_tokens = 10485760 # how many tokens of validation data? it's important to keep this fixed for consistent comparisons + train_seq_len = 48*1024 # FlexAttention sequence length + val_seq_len = 4*64*1024 # FlexAttention sequence length for validation + # optimization + num_iterations = 1750 # number of iterations to run + cooldown_frac = 0.45 # fraction of training spent cooling down the learning rate + # evaluation and logging + val_loss_every = 125 # every how many steps to evaluate val loss? 0 for only at the end + save_checkpoint = False +args = Hyperparameters() + +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + +# torchrun sets these env variables +rank = int(os.environ["RANK"]) +world_size = int(os.environ["WORLD_SIZE"]) +#assert world_size == 8 # this code is designed for 8xH100 +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size +assert torch.cuda.is_available() +device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) +torch.cuda.set_device(device) +dist.init_process_group(backend="nccl", device_id=device) +dist.barrier() +master_process = (rank == 0) # this process will do logging, checkpointing etc. 

# begin logging
# Only the master process creates a log file; `logfile` stays None on other ranks.
logfile = None
if master_process:
    run_id = uuid.uuid4()
    os.makedirs("logs", exist_ok=True)
    logfile = f"logs/{run_id}.txt"
    print(logfile)
def print0(s, console=False):
    """Append `s` to the run's log file on the master process; no-op on other ranks.

    When `console` is true the text is also printed to stdout.
    """
    if master_process:
        with open(logfile, "a") as f:
            if console:
                print(s)
            print(s, file=f)

# begin by printing this file (the Python code)
print0(code)
print0("="*100)
# log information about the hardware/software environment this is running on
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
print0(f"Running Triton version {triton.__version__}")
def nvidia_smi():
    """Run `nvidia-smi` and return its captured stdout as text."""
    import subprocess # avoid top level import
    return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("="*100)

# Build the model on the GPU; max_seq_len covers the longer of the train/val sequence lengths.
model: nn.Module = GPT(vocab_size=50257, num_layers=12, num_heads=6, model_dim=768, max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
# Embedding modules are converted to bfloat16.
for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()
# Broadcast every parameter from rank 0 so all ranks hold the same values.
for param in model.parameters():
    dist.broadcast(param.detach(), 0)

# collect the parameters to optimize
# hidden matrices: parameters under model.blocks with ndim >= 2 whose name lacks "embed"
hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
# embeddings: any model parameter whose name contains "embed"
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
# scalars: every parameter with fewer than 2 dimensions
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]

# init the optimizer(s)
# small adam epsilon by @YouJiacheng.
# this is an alternate method of fixing the world_size dependence
# discovered by @fernbear.bsky.social https://x.com/hi_tysam/status/1879692937589875094
# optimizer1 handles the scalar, head and embedding parameters; optimizer2 the hidden matrices.
optimizer1 = DistAdam(scalar_params + head_params + embed_params, lr=0.008, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, weight_decay=0.0)
optimizers = [optimizer1, optimizer2]
# Remember each group's base lr so the schedule can rescale it every step.
for opt in optimizers:
    for group in opt.param_groups:
        group["initial_lr"] = group["lr"]

# learning rate schedule: stable then decay
def get_lr(step: int):
    """Return the learning-rate multiplier for `step`.

    The multiplier is 1.0 until progress reaches `1 - args.cooldown_frac`, then
    falls linearly towards 0.1 as progress approaches 1.
    """
    x = step / args.num_iterations # progress in training
    assert 0 <= x < 1
    if x < 1 - args.cooldown_frac:
        return 1.0
    else:
        w = (1 - x) / args.cooldown_frac # 1 at the start of cooldown, approaching 0 at the end
        return w * 1.0 + (1 - w) * 0.1

# attention window size schedule: linearly increase
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
    """Return `window_size // 128` as an int32 CUDA tensor.

    The single-entry cache returns the same tensor while `window_size` is unchanged.
    """
    return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
def get_window_size_blocks(step: int):
    """Return the sliding-window size, in 128-token blocks, to use at `step`."""
    x = step / args.num_iterations # progress in training
    assert 0 <= x <= 1
    # Linearly increase the block-wise sliding window size over training 128 -> 1792
    # increase by @fernbear.bsky.social; block-wise by @YouJiacheng
    # NOTE(review): next_multiple_of_n is defined elsewhere in the file; presumably it rounds up to a multiple of n -- confirm
    window_size = next_multiple_of_n(1728 * x, n=128)
    return get_window_size_blocks_helper(window_size)

model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True)

########################################
#            Warmup kernels            #
########################################

# Warmup the training kernels, then re-initialize the state so we aren't cheating
warmup_steps = 10
initial_state = dict(model=copy.deepcopy(model.state_dict()),
                     optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
for _ in range(warmup_steps):
    inputs, targets = next(train_loader)
    model(inputs, targets, get_window_size_blocks(1)).backward()
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
# Restore the saved model and optimizer state so the warmup updates are discarded.
model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
    opt.load_state_dict(opt_state)
del train_loader, initial_state

########################################
#        Training and validation       #
########################################

# A fresh loader is created here, after the warmup loader was deleted.
train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True)
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.perf_counter()
# begin training
train_steps = args.num_iterations
# One extra iteration: the final pass only runs validation (and the optional checkpoint).
for step in range(train_steps + 1):
    last_step = (step == train_steps)

    # --------------- VALIDATION SECTION -----------------
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        # stop the clock
        # Validation time is excluded: elapsed time is banked here and t0 is reset afterwards.
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.perf_counter() - t0)
        model.eval()
        val_batch_size = world_size * args.val_seq_len
        assert args.val_tokens % val_batch_size == 0
        val_steps = args.val_tokens // val_batch_size
        val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False)
        val_loss = 0
        with torch.no_grad():
            for _ in range(val_steps):
                inputs, targets = next(val_loader)
                val_loss += model(inputs, targets, get_window_size_blocks(step))
        val_loss /= val_steps
        del val_loader
        # Average the per-rank validation loss across all ranks.
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step, 1):.2f}ms", console=True)
        model.train()
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.perf_counter()

    if last_step:
        if master_process and args.save_checkpoint:
            log = dict(step=step, code=code, model=model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
            os.makedirs(f"logs/{run_id}", exist_ok=True)
            torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")
        # the last step only has the validation loop, so break to avoid training
        break

    # --------------- TRAINING SECTION -----------------
    # Accumulate gradients over grad_accum_steps micro-batches before one optimizer step.
    for _ in range(grad_accum_steps):
        inputs, targets = next(train_loader)
        model(inputs, targets, get_window_size_blocks(step)).backward()
    # set optimization hyperparameters
    for opt in optimizers:
        for group in opt.param_groups:
            group["lr"] = group["initial_lr"] * get_lr(step)
    for group in optimizer2.param_groups:
        frac = min(step / 300, 1) # momentum warmup for muon
        group["momentum"] = (1 - frac) * 0.85 + frac * 0.95 # 0.85 -> 0.95 over the first 300 steps
    # step the optimizers
    for opt in optimizers:
        opt.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # logging
    # "approx" because the GPU is not synchronized before reading the clock here.
    approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
    print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms/(step + 1):.2f}ms", console=True)

print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
       f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
dist.destroy_process_group()
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:61:00.0 Off | 0 | +| N/A 30C P0 132W / 700W | 5858MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:62:00.0 Off | 0 | +| N/A 32C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:63:00.0 Off | 0 | +| N/A 32C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:64:00.0 Off | 0 | +| N/A 29C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:6A:00.0 Off | 0 | +| N/A 29C P0 123W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:6B:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:6C:00.0 Off | 0 | +| N/A 32C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:6D:00.0 Off | 0 | +| N/A 28C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + 
++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 111824 C /usr/bin/python3 1510MiB | +| 0 N/A N/A 111825 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111826 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111827 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111828 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111829 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111830 C /usr/bin/python3 614MiB | +| 0 N/A N/A 111831 C /usr/bin/python3 614MiB | +| 1 N/A N/A 111825 C /usr/bin/python3 1510MiB | +| 2 N/A N/A 111826 C /usr/bin/python3 1510MiB | +| 3 N/A N/A 111827 C /usr/bin/python3 1510MiB | +| 4 N/A N/A 111828 C /usr/bin/python3 1510MiB | +| 5 N/A N/A 111829 C /usr/bin/python3 1510MiB | +| 6 N/A N/A 111830 C /usr/bin/python3 1510MiB | +| 7 N/A N/A 111831 C /usr/bin/python3 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +step:0/1750 val_loss:10.8258 train_time:0ms step_avg:0.04ms +step:1/1750 train_time:151ms step_avg:151.05ms +step:2/1750 train_time:176ms step_avg:87.79ms +step:3/1750 train_time:240ms step_avg:79.89ms +step:4/1750 train_time:329ms step_avg:82.25ms +step:5/1750 train_time:420ms step_avg:83.91ms +step:6/1750 train_time:510ms step_avg:84.94ms +step:7/1750 train_time:601ms step_avg:85.79ms +step:8/1750 train_time:691ms step_avg:86.32ms +step:9/1750 train_time:781ms step_avg:86.76ms +step:10/1750 train_time:871ms step_avg:87.08ms +step:11/1750 train_time:961ms step_avg:87.37ms +step:12/1750 train_time:1053ms step_avg:87.78ms +step:13/1750 train_time:1148ms step_avg:88.27ms +step:14/1750 train_time:1242ms step_avg:88.73ms +step:15/1750 train_time:1334ms step_avg:88.93ms +step:16/1750 
train_time:1425ms step_avg:89.09ms +step:17/1750 train_time:1517ms step_avg:89.22ms +step:18/1750 train_time:1608ms step_avg:89.31ms +step:19/1750 train_time:1699ms step_avg:89.40ms +step:20/1750 train_time:1790ms step_avg:89.49ms +step:21/1750 train_time:1880ms step_avg:89.52ms +step:22/1750 train_time:1970ms step_avg:89.56ms +step:23/1750 train_time:2062ms step_avg:89.65ms +step:24/1750 train_time:2156ms step_avg:89.82ms +step:25/1750 train_time:2248ms step_avg:89.92ms +step:26/1750 train_time:2341ms step_avg:90.03ms +step:27/1750 train_time:2432ms step_avg:90.07ms +step:28/1750 train_time:2523ms step_avg:90.10ms +step:29/1750 train_time:2613ms step_avg:90.12ms +step:30/1750 train_time:2705ms step_avg:90.17ms +step:31/1750 train_time:2796ms step_avg:90.19ms +step:32/1750 train_time:2887ms step_avg:90.21ms +step:33/1750 train_time:2978ms step_avg:90.24ms +step:34/1750 train_time:3069ms step_avg:90.27ms +step:35/1750 train_time:3162ms step_avg:90.33ms +step:36/1750 train_time:3254ms step_avg:90.38ms +step:37/1750 train_time:3346ms step_avg:90.42ms +step:38/1750 train_time:3438ms step_avg:90.47ms +step:39/1750 train_time:3529ms step_avg:90.49ms +step:40/1750 train_time:3621ms step_avg:90.52ms +step:41/1750 train_time:3712ms step_avg:90.55ms +step:42/1750 train_time:3803ms step_avg:90.55ms +step:43/1750 train_time:3894ms step_avg:90.56ms +step:44/1750 train_time:3985ms step_avg:90.57ms +step:45/1750 train_time:4077ms step_avg:90.60ms +step:46/1750 train_time:4169ms step_avg:90.63ms +step:47/1750 train_time:4261ms step_avg:90.65ms +step:48/1750 train_time:4352ms step_avg:90.66ms +step:49/1750 train_time:4444ms step_avg:90.70ms +step:50/1750 train_time:4535ms step_avg:90.70ms +step:51/1750 train_time:4628ms step_avg:90.74ms +step:52/1750 train_time:4719ms step_avg:90.75ms +step:53/1750 train_time:4810ms step_avg:90.76ms +step:54/1750 train_time:4902ms step_avg:90.77ms +step:55/1750 train_time:4993ms step_avg:90.78ms +step:56/1750 train_time:5085ms step_avg:90.80ms 
+step:57/1750 train_time:5176ms step_avg:90.80ms +step:58/1750 train_time:5267ms step_avg:90.82ms +step:59/1750 train_time:5359ms step_avg:90.84ms +step:60/1750 train_time:5451ms step_avg:90.85ms +step:61/1750 train_time:5543ms step_avg:90.86ms +step:62/1750 train_time:5634ms step_avg:90.88ms +step:63/1750 train_time:5726ms step_avg:90.89ms +step:64/1750 train_time:5819ms step_avg:90.91ms +step:65/1750 train_time:5910ms step_avg:90.93ms +step:66/1750 train_time:6002ms step_avg:90.94ms +step:67/1750 train_time:6093ms step_avg:90.94ms +step:68/1750 train_time:6185ms step_avg:90.95ms +step:69/1750 train_time:6276ms step_avg:90.96ms +step:70/1750 train_time:6368ms step_avg:90.97ms +step:71/1750 train_time:6459ms step_avg:90.98ms +step:72/1750 train_time:6551ms step_avg:90.99ms +step:73/1750 train_time:6643ms step_avg:91.00ms +step:74/1750 train_time:6734ms step_avg:91.00ms +step:75/1750 train_time:6826ms step_avg:91.02ms +step:76/1750 train_time:6918ms step_avg:91.02ms +step:77/1750 train_time:7009ms step_avg:91.03ms +step:78/1750 train_time:7101ms step_avg:91.04ms +step:79/1750 train_time:7192ms step_avg:91.04ms +step:80/1750 train_time:7284ms step_avg:91.05ms +step:81/1750 train_time:7375ms step_avg:91.05ms +step:82/1750 train_time:7466ms step_avg:91.05ms +step:83/1750 train_time:7558ms step_avg:91.06ms +step:84/1750 train_time:7650ms step_avg:91.07ms +step:85/1750 train_time:7742ms step_avg:91.08ms +step:86/1750 train_time:7833ms step_avg:91.08ms +step:87/1750 train_time:7925ms step_avg:91.09ms +step:88/1750 train_time:8016ms step_avg:91.09ms +step:89/1750 train_time:8108ms step_avg:91.10ms +step:90/1750 train_time:8200ms step_avg:91.11ms +step:91/1750 train_time:8290ms step_avg:91.10ms +step:92/1750 train_time:8382ms step_avg:91.11ms +step:93/1750 train_time:8474ms step_avg:91.12ms +step:94/1750 train_time:8566ms step_avg:91.12ms +step:95/1750 train_time:8657ms step_avg:91.13ms +step:96/1750 train_time:8749ms step_avg:91.14ms +step:97/1750 train_time:8841ms 
step_avg:91.14ms +step:98/1750 train_time:8931ms step_avg:91.14ms +step:99/1750 train_time:9023ms step_avg:91.14ms +step:100/1750 train_time:9114ms step_avg:91.14ms +step:101/1750 train_time:9207ms step_avg:91.15ms +step:102/1750 train_time:9298ms step_avg:91.15ms +step:103/1750 train_time:9389ms step_avg:91.16ms +step:104/1750 train_time:9481ms step_avg:91.16ms +step:105/1750 train_time:9573ms step_avg:91.17ms +step:106/1750 train_time:9665ms step_avg:91.17ms +step:107/1750 train_time:9756ms step_avg:91.18ms +step:108/1750 train_time:9847ms step_avg:91.18ms +step:109/1750 train_time:9939ms step_avg:91.18ms +step:110/1750 train_time:10031ms step_avg:91.19ms +step:111/1750 train_time:10123ms step_avg:91.20ms +step:112/1750 train_time:10215ms step_avg:91.20ms +step:113/1750 train_time:10307ms step_avg:91.21ms +step:114/1750 train_time:10398ms step_avg:91.21ms +step:115/1750 train_time:10490ms step_avg:91.22ms +step:116/1750 train_time:10581ms step_avg:91.22ms +step:117/1750 train_time:10673ms step_avg:91.22ms +step:118/1750 train_time:10765ms step_avg:91.23ms +step:119/1750 train_time:10856ms step_avg:91.23ms +step:120/1750 train_time:10947ms step_avg:91.23ms +step:121/1750 train_time:11039ms step_avg:91.23ms +step:122/1750 train_time:11131ms step_avg:91.24ms +step:123/1750 train_time:11224ms step_avg:91.25ms +step:124/1750 train_time:11315ms step_avg:91.25ms +step:125/1750 train_time:11407ms step_avg:91.26ms +step:125/1750 val_loss:4.6317 train_time:11502ms step_avg:92.02ms +step:126/1750 train_time:11528ms step_avg:91.49ms +step:127/1750 train_time:11599ms step_avg:91.33ms +step:128/1750 train_time:11696ms step_avg:91.38ms +step:129/1750 train_time:11790ms step_avg:91.39ms +step:130/1750 train_time:11881ms step_avg:91.39ms +step:131/1750 train_time:11972ms step_avg:91.39ms +step:132/1750 train_time:12062ms step_avg:91.38ms +step:133/1750 train_time:12153ms step_avg:91.38ms +step:134/1750 train_time:12244ms step_avg:91.38ms +step:135/1750 train_time:12336ms 
step_avg:91.37ms +step:136/1750 train_time:12427ms step_avg:91.37ms +step:137/1750 train_time:12520ms step_avg:91.39ms +step:138/1750 train_time:12613ms step_avg:91.40ms +step:139/1750 train_time:12708ms step_avg:91.42ms +step:140/1750 train_time:12802ms step_avg:91.44ms +step:141/1750 train_time:12894ms step_avg:91.45ms +step:142/1750 train_time:12986ms step_avg:91.45ms +step:143/1750 train_time:13078ms step_avg:91.45ms +step:144/1750 train_time:13169ms step_avg:91.45ms +step:145/1750 train_time:13260ms step_avg:91.45ms +step:146/1750 train_time:13351ms step_avg:91.44ms +step:147/1750 train_time:13442ms step_avg:91.44ms +step:148/1750 train_time:13534ms step_avg:91.45ms +step:149/1750 train_time:13627ms step_avg:91.46ms +step:150/1750 train_time:13721ms step_avg:91.48ms +step:151/1750 train_time:13813ms step_avg:91.48ms +step:152/1750 train_time:13906ms step_avg:91.49ms +step:153/1750 train_time:13998ms step_avg:91.49ms +step:154/1750 train_time:14089ms step_avg:91.49ms +step:155/1750 train_time:14180ms step_avg:91.49ms +step:156/1750 train_time:14272ms step_avg:91.48ms +step:157/1750 train_time:14363ms step_avg:91.48ms +step:158/1750 train_time:14455ms step_avg:91.49ms +step:159/1750 train_time:14547ms step_avg:91.49ms +step:160/1750 train_time:14640ms step_avg:91.50ms +step:161/1750 train_time:14732ms step_avg:91.50ms +step:162/1750 train_time:14825ms step_avg:91.51ms +step:163/1750 train_time:14917ms step_avg:91.51ms +step:164/1750 train_time:15008ms step_avg:91.51ms +step:165/1750 train_time:15101ms step_avg:91.52ms +step:166/1750 train_time:15193ms step_avg:91.52ms +step:167/1750 train_time:15284ms step_avg:91.52ms +step:168/1750 train_time:15376ms step_avg:91.52ms +step:169/1750 train_time:15468ms step_avg:91.53ms +step:170/1750 train_time:15561ms step_avg:91.54ms +step:171/1750 train_time:15653ms step_avg:91.54ms +step:172/1750 train_time:15746ms step_avg:91.55ms +step:173/1750 train_time:15838ms step_avg:91.55ms +step:174/1750 train_time:15929ms 
step_avg:91.55ms +step:175/1750 train_time:16021ms step_avg:91.55ms +step:176/1750 train_time:16112ms step_avg:91.55ms +step:177/1750 train_time:16204ms step_avg:91.55ms +step:178/1750 train_time:16296ms step_avg:91.55ms +step:179/1750 train_time:16387ms step_avg:91.55ms +step:180/1750 train_time:16480ms step_avg:91.56ms +step:181/1750 train_time:16572ms step_avg:91.56ms +step:182/1750 train_time:16665ms step_avg:91.57ms +step:183/1750 train_time:16757ms step_avg:91.57ms +step:184/1750 train_time:16850ms step_avg:91.58ms +step:185/1750 train_time:16943ms step_avg:91.58ms +step:186/1750 train_time:17034ms step_avg:91.58ms +step:187/1750 train_time:17126ms step_avg:91.58ms +step:188/1750 train_time:17218ms step_avg:91.58ms +step:189/1750 train_time:17309ms step_avg:91.58ms +step:190/1750 train_time:17402ms step_avg:91.59ms +step:191/1750 train_time:17494ms step_avg:91.59ms +step:192/1750 train_time:17587ms step_avg:91.60ms +step:193/1750 train_time:17678ms step_avg:91.60ms +step:194/1750 train_time:17770ms step_avg:91.60ms +step:195/1750 train_time:17862ms step_avg:91.60ms +step:196/1750 train_time:17954ms step_avg:91.60ms +step:197/1750 train_time:18046ms step_avg:91.60ms +step:198/1750 train_time:18137ms step_avg:91.60ms +step:199/1750 train_time:18229ms step_avg:91.61ms +step:200/1750 train_time:18321ms step_avg:91.60ms +step:201/1750 train_time:18412ms step_avg:91.60ms +step:202/1750 train_time:18504ms step_avg:91.61ms +step:203/1750 train_time:18597ms step_avg:91.61ms +step:204/1750 train_time:18689ms step_avg:91.61ms +step:205/1750 train_time:18781ms step_avg:91.61ms +step:206/1750 train_time:18874ms step_avg:91.62ms +step:207/1750 train_time:18966ms step_avg:91.62ms +step:208/1750 train_time:19057ms step_avg:91.62ms +step:209/1750 train_time:19149ms step_avg:91.62ms +step:210/1750 train_time:19241ms step_avg:91.62ms +step:211/1750 train_time:19332ms step_avg:91.62ms +step:212/1750 train_time:19425ms step_avg:91.63ms +step:213/1750 train_time:19517ms 
step_avg:91.63ms +step:214/1750 train_time:19609ms step_avg:91.63ms +step:215/1750 train_time:19702ms step_avg:91.63ms +step:216/1750 train_time:19794ms step_avg:91.64ms +step:217/1750 train_time:19886ms step_avg:91.64ms +step:218/1750 train_time:19978ms step_avg:91.64ms +step:219/1750 train_time:20070ms step_avg:91.65ms +step:220/1750 train_time:20162ms step_avg:91.65ms +step:221/1750 train_time:20254ms step_avg:91.65ms +step:222/1750 train_time:20346ms step_avg:91.65ms +step:223/1750 train_time:20438ms step_avg:91.65ms +step:224/1750 train_time:20530ms step_avg:91.65ms +step:225/1750 train_time:20622ms step_avg:91.65ms +step:226/1750 train_time:20714ms step_avg:91.65ms +step:227/1750 train_time:20806ms step_avg:91.66ms +step:228/1750 train_time:20899ms step_avg:91.66ms +step:229/1750 train_time:20990ms step_avg:91.66ms +step:230/1750 train_time:21082ms step_avg:91.66ms +step:231/1750 train_time:21174ms step_avg:91.66ms +step:232/1750 train_time:21267ms step_avg:91.67ms +step:233/1750 train_time:21360ms step_avg:91.67ms +step:234/1750 train_time:21450ms step_avg:91.67ms +step:235/1750 train_time:21542ms step_avg:91.67ms +step:236/1750 train_time:21633ms step_avg:91.67ms +step:237/1750 train_time:21726ms step_avg:91.67ms +step:238/1750 train_time:21819ms step_avg:91.67ms +step:239/1750 train_time:21910ms step_avg:91.68ms +step:240/1750 train_time:22002ms step_avg:91.68ms +step:241/1750 train_time:22094ms step_avg:91.68ms +step:242/1750 train_time:22186ms step_avg:91.68ms +step:243/1750 train_time:22278ms step_avg:91.68ms +step:244/1750 train_time:22370ms step_avg:91.68ms +step:245/1750 train_time:22462ms step_avg:91.68ms +step:246/1750 train_time:22554ms step_avg:91.68ms +step:247/1750 train_time:22646ms step_avg:91.68ms +step:248/1750 train_time:22738ms step_avg:91.69ms +step:249/1750 train_time:22830ms step_avg:91.69ms +step:250/1750 train_time:22923ms step_avg:91.69ms +step:250/1750 val_loss:4.0972 train_time:23018ms step_avg:92.07ms +step:251/1750 
train_time:23041ms step_avg:91.80ms +step:252/1750 train_time:23113ms step_avg:91.72ms +step:253/1750 train_time:23211ms step_avg:91.74ms +step:254/1750 train_time:23303ms step_avg:91.75ms +step:255/1750 train_time:23395ms step_avg:91.74ms +step:256/1750 train_time:23486ms step_avg:91.74ms +step:257/1750 train_time:23577ms step_avg:91.74ms +step:258/1750 train_time:23668ms step_avg:91.73ms +step:259/1750 train_time:23759ms step_avg:91.73ms +step:260/1750 train_time:23849ms step_avg:91.73ms +step:261/1750 train_time:23940ms step_avg:91.73ms +step:262/1750 train_time:24034ms step_avg:91.73ms +step:263/1750 train_time:24128ms step_avg:91.74ms +step:264/1750 train_time:24221ms step_avg:91.75ms +step:265/1750 train_time:24314ms step_avg:91.75ms +step:266/1750 train_time:24406ms step_avg:91.75ms +step:267/1750 train_time:24499ms step_avg:91.75ms +step:268/1750 train_time:24591ms step_avg:91.76ms +step:269/1750 train_time:24683ms step_avg:91.76ms +step:270/1750 train_time:24775ms step_avg:91.76ms +step:271/1750 train_time:24867ms step_avg:91.76ms +step:272/1750 train_time:24959ms step_avg:91.76ms +step:273/1750 train_time:25051ms step_avg:91.76ms +step:274/1750 train_time:25144ms step_avg:91.77ms +step:275/1750 train_time:25237ms step_avg:91.77ms +step:276/1750 train_time:25330ms step_avg:91.78ms +step:277/1750 train_time:25423ms step_avg:91.78ms +step:278/1750 train_time:25515ms step_avg:91.78ms +step:279/1750 train_time:25607ms step_avg:91.78ms +step:280/1750 train_time:25699ms step_avg:91.78ms +step:281/1750 train_time:25792ms step_avg:91.79ms +step:282/1750 train_time:25884ms step_avg:91.79ms +step:283/1750 train_time:25976ms step_avg:91.79ms +step:284/1750 train_time:26068ms step_avg:91.79ms +step:285/1750 train_time:26161ms step_avg:91.79ms +step:286/1750 train_time:26254ms step_avg:91.80ms +step:287/1750 train_time:26346ms step_avg:91.80ms +step:288/1750 train_time:26439ms step_avg:91.80ms +step:289/1750 train_time:26532ms step_avg:91.81ms +step:290/1750 
train_time:26624ms step_avg:91.81ms +step:291/1750 train_time:26716ms step_avg:91.81ms +step:292/1750 train_time:26808ms step_avg:91.81ms +step:293/1750 train_time:26900ms step_avg:91.81ms +step:294/1750 train_time:26992ms step_avg:91.81ms +step:295/1750 train_time:27085ms step_avg:91.81ms +step:296/1750 train_time:27177ms step_avg:91.82ms +step:297/1750 train_time:27269ms step_avg:91.82ms +step:298/1750 train_time:27362ms step_avg:91.82ms +step:299/1750 train_time:27455ms step_avg:91.82ms +step:300/1750 train_time:27548ms step_avg:91.83ms +step:301/1750 train_time:27640ms step_avg:91.83ms +step:302/1750 train_time:27732ms step_avg:91.83ms +step:303/1750 train_time:27824ms step_avg:91.83ms +step:304/1750 train_time:27916ms step_avg:91.83ms +step:305/1750 train_time:28008ms step_avg:91.83ms +step:306/1750 train_time:28101ms step_avg:91.83ms +step:307/1750 train_time:28193ms step_avg:91.83ms +step:308/1750 train_time:28285ms step_avg:91.84ms +step:309/1750 train_time:28378ms step_avg:91.84ms +step:310/1750 train_time:28470ms step_avg:91.84ms +step:311/1750 train_time:28563ms step_avg:91.84ms +step:312/1750 train_time:28656ms step_avg:91.85ms +step:313/1750 train_time:28747ms step_avg:91.84ms +step:314/1750 train_time:28839ms step_avg:91.84ms +step:315/1750 train_time:28932ms step_avg:91.85ms +step:316/1750 train_time:29024ms step_avg:91.85ms +step:317/1750 train_time:29117ms step_avg:91.85ms +step:318/1750 train_time:29209ms step_avg:91.85ms +step:319/1750 train_time:29303ms step_avg:91.86ms +step:320/1750 train_time:29395ms step_avg:91.86ms +step:321/1750 train_time:29488ms step_avg:91.86ms +step:322/1750 train_time:29580ms step_avg:91.86ms +step:323/1750 train_time:29672ms step_avg:91.86ms +step:324/1750 train_time:29764ms step_avg:91.86ms +step:325/1750 train_time:29856ms step_avg:91.86ms +step:326/1750 train_time:29948ms step_avg:91.86ms +step:327/1750 train_time:30040ms step_avg:91.87ms +step:328/1750 train_time:30133ms step_avg:91.87ms +step:329/1750 
train_time:30226ms step_avg:91.87ms +step:330/1750 train_time:30318ms step_avg:91.87ms +step:331/1750 train_time:30411ms step_avg:91.88ms +step:332/1750 train_time:30504ms step_avg:91.88ms +step:333/1750 train_time:30597ms step_avg:91.88ms +step:334/1750 train_time:30690ms step_avg:91.89ms +step:335/1750 train_time:30782ms step_avg:91.89ms +step:336/1750 train_time:30874ms step_avg:91.89ms +step:337/1750 train_time:30966ms step_avg:91.89ms +step:338/1750 train_time:31059ms step_avg:91.89ms +step:339/1750 train_time:31152ms step_avg:91.89ms +step:340/1750 train_time:31244ms step_avg:91.90ms +step:341/1750 train_time:31338ms step_avg:91.90ms +step:342/1750 train_time:31431ms step_avg:91.90ms +step:343/1750 train_time:31523ms step_avg:91.90ms +step:344/1750 train_time:31615ms step_avg:91.90ms +step:345/1750 train_time:31707ms step_avg:91.90ms +step:346/1750 train_time:31800ms step_avg:91.91ms +step:347/1750 train_time:31892ms step_avg:91.91ms +step:348/1750 train_time:31984ms step_avg:91.91ms +step:349/1750 train_time:32077ms step_avg:91.91ms +step:350/1750 train_time:32168ms step_avg:91.91ms +step:351/1750 train_time:32261ms step_avg:91.91ms +step:352/1750 train_time:32353ms step_avg:91.91ms +step:353/1750 train_time:32445ms step_avg:91.91ms +step:354/1750 train_time:32539ms step_avg:91.92ms +step:355/1750 train_time:32631ms step_avg:91.92ms +step:356/1750 train_time:32724ms step_avg:91.92ms +step:357/1750 train_time:32816ms step_avg:91.92ms +step:358/1750 train_time:32908ms step_avg:91.92ms +step:359/1750 train_time:33000ms step_avg:91.92ms +step:360/1750 train_time:33092ms step_avg:91.92ms +step:361/1750 train_time:33185ms step_avg:91.92ms +step:362/1750 train_time:33277ms step_avg:91.93ms +step:363/1750 train_time:33369ms step_avg:91.93ms +step:364/1750 train_time:33461ms step_avg:91.93ms +step:365/1750 train_time:33554ms step_avg:91.93ms +step:366/1750 train_time:33646ms step_avg:91.93ms +step:367/1750 train_time:33738ms step_avg:91.93ms +step:368/1750 
train_time:33830ms step_avg:91.93ms +step:369/1750 train_time:33923ms step_avg:91.93ms +step:370/1750 train_time:34015ms step_avg:91.93ms +step:371/1750 train_time:34107ms step_avg:91.93ms +step:372/1750 train_time:34200ms step_avg:91.93ms +step:373/1750 train_time:34292ms step_avg:91.94ms +step:374/1750 train_time:34385ms step_avg:91.94ms +step:375/1750 train_time:34477ms step_avg:91.94ms +step:375/1750 val_loss:3.8911 train_time:34573ms step_avg:92.19ms +step:376/1750 train_time:34597ms step_avg:92.01ms +step:377/1750 train_time:34671ms step_avg:91.97ms +step:378/1750 train_time:34767ms step_avg:91.98ms +step:379/1750 train_time:34861ms step_avg:91.98ms +step:380/1750 train_time:34953ms step_avg:91.98ms +step:381/1750 train_time:35045ms step_avg:91.98ms +step:382/1750 train_time:35136ms step_avg:91.98ms +step:383/1750 train_time:35228ms step_avg:91.98ms +step:384/1750 train_time:35320ms step_avg:91.98ms +step:385/1750 train_time:35411ms step_avg:91.98ms +step:386/1750 train_time:35503ms step_avg:91.98ms +step:387/1750 train_time:35597ms step_avg:91.98ms +step:388/1750 train_time:35691ms step_avg:91.99ms +step:389/1750 train_time:35786ms step_avg:91.99ms +step:390/1750 train_time:35878ms step_avg:92.00ms +step:391/1750 train_time:35973ms step_avg:92.00ms +step:392/1750 train_time:36067ms step_avg:92.01ms +step:393/1750 train_time:36160ms step_avg:92.01ms +step:394/1750 train_time:36253ms step_avg:92.01ms +step:395/1750 train_time:36346ms step_avg:92.01ms +step:396/1750 train_time:36439ms step_avg:92.02ms +step:397/1750 train_time:36533ms step_avg:92.02ms +step:398/1750 train_time:36628ms step_avg:92.03ms +step:399/1750 train_time:36723ms step_avg:92.04ms +step:400/1750 train_time:36819ms step_avg:92.05ms +step:401/1750 train_time:36913ms step_avg:92.05ms +step:402/1750 train_time:37007ms step_avg:92.06ms +step:403/1750 train_time:37101ms step_avg:92.06ms +step:404/1750 train_time:37195ms step_avg:92.07ms +step:405/1750 train_time:37288ms step_avg:92.07ms 
+step:406/1750 train_time:37381ms step_avg:92.07ms +step:407/1750 train_time:37476ms step_avg:92.08ms +step:408/1750 train_time:37570ms step_avg:92.08ms +step:409/1750 train_time:37665ms step_avg:92.09ms +step:410/1750 train_time:37760ms step_avg:92.10ms +step:411/1750 train_time:37855ms step_avg:92.10ms +step:412/1750 train_time:37949ms step_avg:92.11ms +step:413/1750 train_time:38044ms step_avg:92.12ms +step:414/1750 train_time:38138ms step_avg:92.12ms +step:415/1750 train_time:38232ms step_avg:92.12ms +step:416/1750 train_time:38325ms step_avg:92.13ms +step:417/1750 train_time:38419ms step_avg:92.13ms +step:418/1750 train_time:38514ms step_avg:92.14ms +step:419/1750 train_time:38608ms step_avg:92.14ms +step:420/1750 train_time:38703ms step_avg:92.15ms +step:421/1750 train_time:38798ms step_avg:92.16ms +step:422/1750 train_time:38893ms step_avg:92.16ms +step:423/1750 train_time:38987ms step_avg:92.17ms +step:424/1750 train_time:39081ms step_avg:92.17ms +step:425/1750 train_time:39176ms step_avg:92.18ms +step:426/1750 train_time:39269ms step_avg:92.18ms +step:427/1750 train_time:39363ms step_avg:92.18ms +step:428/1750 train_time:39458ms step_avg:92.19ms +step:429/1750 train_time:39552ms step_avg:92.20ms +step:430/1750 train_time:39646ms step_avg:92.20ms +step:431/1750 train_time:39741ms step_avg:92.21ms +step:432/1750 train_time:39837ms step_avg:92.21ms +step:433/1750 train_time:39931ms step_avg:92.22ms +step:434/1750 train_time:40026ms step_avg:92.23ms +step:435/1750 train_time:40120ms step_avg:92.23ms +step:436/1750 train_time:40214ms step_avg:92.23ms +step:437/1750 train_time:40308ms step_avg:92.24ms +step:438/1750 train_time:40402ms step_avg:92.24ms +step:439/1750 train_time:40497ms step_avg:92.25ms +step:440/1750 train_time:40591ms step_avg:92.25ms +step:441/1750 train_time:40685ms step_avg:92.26ms +step:442/1750 train_time:40780ms step_avg:92.26ms +step:443/1750 train_time:40876ms step_avg:92.27ms +step:444/1750 train_time:40971ms step_avg:92.28ms 
+step:445/1750 train_time:41065ms step_avg:92.28ms +step:446/1750 train_time:41159ms step_avg:92.29ms +step:447/1750 train_time:41254ms step_avg:92.29ms +step:448/1750 train_time:41348ms step_avg:92.29ms +step:449/1750 train_time:41442ms step_avg:92.30ms +step:450/1750 train_time:41536ms step_avg:92.30ms +step:451/1750 train_time:41630ms step_avg:92.31ms +step:452/1750 train_time:41724ms step_avg:92.31ms +step:453/1750 train_time:41819ms step_avg:92.32ms +step:454/1750 train_time:41914ms step_avg:92.32ms +step:455/1750 train_time:42008ms step_avg:92.33ms +step:456/1750 train_time:42102ms step_avg:92.33ms +step:457/1750 train_time:42196ms step_avg:92.33ms +step:458/1750 train_time:42291ms step_avg:92.34ms +step:459/1750 train_time:42385ms step_avg:92.34ms +step:460/1750 train_time:42479ms step_avg:92.35ms +step:461/1750 train_time:42573ms step_avg:92.35ms +step:462/1750 train_time:42667ms step_avg:92.35ms +step:463/1750 train_time:42761ms step_avg:92.36ms +step:464/1750 train_time:42855ms step_avg:92.36ms +step:465/1750 train_time:42950ms step_avg:92.36ms +step:466/1750 train_time:43044ms step_avg:92.37ms +step:467/1750 train_time:43138ms step_avg:92.37ms +step:468/1750 train_time:43233ms step_avg:92.38ms +step:469/1750 train_time:43327ms step_avg:92.38ms +step:470/1750 train_time:43421ms step_avg:92.38ms +step:471/1750 train_time:43515ms step_avg:92.39ms +step:472/1750 train_time:43609ms step_avg:92.39ms +step:473/1750 train_time:43704ms step_avg:92.40ms +step:474/1750 train_time:43798ms step_avg:92.40ms +step:475/1750 train_time:43892ms step_avg:92.41ms +step:476/1750 train_time:43987ms step_avg:92.41ms +step:477/1750 train_time:44080ms step_avg:92.41ms +step:478/1750 train_time:44174ms step_avg:92.41ms +step:479/1750 train_time:44269ms step_avg:92.42ms +step:480/1750 train_time:44363ms step_avg:92.42ms +step:481/1750 train_time:44458ms step_avg:92.43ms +step:482/1750 train_time:44552ms step_avg:92.43ms +step:483/1750 train_time:44645ms step_avg:92.43ms 
+step:484/1750 train_time:44739ms step_avg:92.44ms +step:485/1750 train_time:44833ms step_avg:92.44ms +step:486/1750 train_time:44927ms step_avg:92.44ms +step:487/1750 train_time:45022ms step_avg:92.45ms +step:488/1750 train_time:45116ms step_avg:92.45ms +step:489/1750 train_time:45211ms step_avg:92.46ms +step:490/1750 train_time:45305ms step_avg:92.46ms +step:491/1750 train_time:45400ms step_avg:92.46ms +step:492/1750 train_time:45495ms step_avg:92.47ms +step:493/1750 train_time:45589ms step_avg:92.47ms +step:494/1750 train_time:45683ms step_avg:92.48ms +step:495/1750 train_time:45778ms step_avg:92.48ms +step:496/1750 train_time:45873ms step_avg:92.49ms +step:497/1750 train_time:45968ms step_avg:92.49ms +step:498/1750 train_time:46063ms step_avg:92.50ms +step:499/1750 train_time:46157ms step_avg:92.50ms +step:500/1750 train_time:46251ms step_avg:92.50ms +step:500/1750 val_loss:3.7412 train_time:46349ms step_avg:92.70ms +step:501/1750 train_time:46373ms step_avg:92.56ms +step:502/1750 train_time:46448ms step_avg:92.53ms +step:503/1750 train_time:46548ms step_avg:92.54ms +step:504/1750 train_time:46642ms step_avg:92.54ms +step:505/1750 train_time:46738ms step_avg:92.55ms +step:506/1750 train_time:46831ms step_avg:92.55ms +step:507/1750 train_time:46923ms step_avg:92.55ms +step:508/1750 train_time:47016ms step_avg:92.55ms +step:509/1750 train_time:47109ms step_avg:92.55ms +step:510/1750 train_time:47202ms step_avg:92.55ms +step:511/1750 train_time:47296ms step_avg:92.56ms +step:512/1750 train_time:47392ms step_avg:92.56ms +step:513/1750 train_time:47489ms step_avg:92.57ms +step:514/1750 train_time:47584ms step_avg:92.58ms +step:515/1750 train_time:47679ms step_avg:92.58ms +step:516/1750 train_time:47773ms step_avg:92.58ms +step:517/1750 train_time:47867ms step_avg:92.59ms +step:518/1750 train_time:47960ms step_avg:92.59ms +step:519/1750 train_time:48054ms step_avg:92.59ms +step:520/1750 train_time:48147ms step_avg:92.59ms +step:521/1750 train_time:48241ms 
step_avg:92.59ms +step:522/1750 train_time:48335ms step_avg:92.60ms +step:523/1750 train_time:48431ms step_avg:92.60ms +step:524/1750 train_time:48526ms step_avg:92.61ms +step:525/1750 train_time:48622ms step_avg:92.61ms +step:526/1750 train_time:48718ms step_avg:92.62ms +step:527/1750 train_time:48812ms step_avg:92.62ms +step:528/1750 train_time:48906ms step_avg:92.62ms +step:529/1750 train_time:49000ms step_avg:92.63ms +step:530/1750 train_time:49094ms step_avg:92.63ms +step:531/1750 train_time:49189ms step_avg:92.63ms +step:532/1750 train_time:49283ms step_avg:92.64ms +step:533/1750 train_time:49378ms step_avg:92.64ms +step:534/1750 train_time:49473ms step_avg:92.65ms +step:535/1750 train_time:49569ms step_avg:92.65ms +step:536/1750 train_time:49664ms step_avg:92.66ms +step:537/1750 train_time:49759ms step_avg:92.66ms +step:538/1750 train_time:49853ms step_avg:92.66ms +step:539/1750 train_time:49947ms step_avg:92.67ms +step:540/1750 train_time:50041ms step_avg:92.67ms +step:541/1750 train_time:50136ms step_avg:92.67ms +step:542/1750 train_time:50230ms step_avg:92.68ms +step:543/1750 train_time:50324ms step_avg:92.68ms +step:544/1750 train_time:50419ms step_avg:92.68ms +step:545/1750 train_time:50514ms step_avg:92.69ms +step:546/1750 train_time:50609ms step_avg:92.69ms +step:547/1750 train_time:50703ms step_avg:92.69ms +step:548/1750 train_time:50799ms step_avg:92.70ms +step:549/1750 train_time:50893ms step_avg:92.70ms +step:550/1750 train_time:50988ms step_avg:92.70ms +step:551/1750 train_time:51082ms step_avg:92.71ms +step:552/1750 train_time:51176ms step_avg:92.71ms +step:553/1750 train_time:51270ms step_avg:92.71ms +step:554/1750 train_time:51364ms step_avg:92.72ms +step:555/1750 train_time:51460ms step_avg:92.72ms +step:556/1750 train_time:51555ms step_avg:92.73ms +step:557/1750 train_time:51650ms step_avg:92.73ms +step:558/1750 train_time:51745ms step_avg:92.73ms +step:559/1750 train_time:51839ms step_avg:92.74ms +step:560/1750 train_time:51934ms 
step_avg:92.74ms +step:561/1750 train_time:52028ms step_avg:92.74ms +step:562/1750 train_time:52122ms step_avg:92.74ms +step:563/1750 train_time:52216ms step_avg:92.75ms +step:564/1750 train_time:52311ms step_avg:92.75ms +step:565/1750 train_time:52406ms step_avg:92.75ms +step:566/1750 train_time:52501ms step_avg:92.76ms +step:567/1750 train_time:52596ms step_avg:92.76ms +step:568/1750 train_time:52691ms step_avg:92.77ms +step:569/1750 train_time:52785ms step_avg:92.77ms +step:570/1750 train_time:52880ms step_avg:92.77ms +step:571/1750 train_time:52976ms step_avg:92.78ms +step:572/1750 train_time:53069ms step_avg:92.78ms +step:573/1750 train_time:53163ms step_avg:92.78ms +step:574/1750 train_time:53258ms step_avg:92.78ms +step:575/1750 train_time:53353ms step_avg:92.79ms +step:576/1750 train_time:53447ms step_avg:92.79ms +step:577/1750 train_time:53542ms step_avg:92.79ms +step:578/1750 train_time:53636ms step_avg:92.80ms +step:579/1750 train_time:53731ms step_avg:92.80ms +step:580/1750 train_time:53827ms step_avg:92.81ms +step:581/1750 train_time:53922ms step_avg:92.81ms +step:582/1750 train_time:54017ms step_avg:92.81ms +step:583/1750 train_time:54112ms step_avg:92.82ms +step:584/1750 train_time:54205ms step_avg:92.82ms +step:585/1750 train_time:54300ms step_avg:92.82ms +step:586/1750 train_time:54394ms step_avg:92.82ms +step:587/1750 train_time:54490ms step_avg:92.83ms +step:588/1750 train_time:54584ms step_avg:92.83ms +step:589/1750 train_time:54679ms step_avg:92.83ms +step:590/1750 train_time:54773ms step_avg:92.84ms +step:591/1750 train_time:54868ms step_avg:92.84ms +step:592/1750 train_time:54963ms step_avg:92.84ms +step:593/1750 train_time:55058ms step_avg:92.85ms +step:594/1750 train_time:55152ms step_avg:92.85ms +step:595/1750 train_time:55245ms step_avg:92.85ms +step:596/1750 train_time:55340ms step_avg:92.85ms +step:597/1750 train_time:55434ms step_avg:92.85ms +step:598/1750 train_time:55529ms step_avg:92.86ms +step:599/1750 train_time:55624ms 
step_avg:92.86ms +step:600/1750 train_time:55719ms step_avg:92.86ms +step:601/1750 train_time:55813ms step_avg:92.87ms +step:602/1750 train_time:55910ms step_avg:92.87ms +step:603/1750 train_time:56004ms step_avg:92.88ms +step:604/1750 train_time:56098ms step_avg:92.88ms +step:605/1750 train_time:56193ms step_avg:92.88ms +step:606/1750 train_time:56288ms step_avg:92.88ms +step:607/1750 train_time:56383ms step_avg:92.89ms +step:608/1750 train_time:56478ms step_avg:92.89ms +step:609/1750 train_time:56572ms step_avg:92.89ms +step:610/1750 train_time:56668ms step_avg:92.90ms +step:611/1750 train_time:56762ms step_avg:92.90ms +step:612/1750 train_time:56858ms step_avg:92.91ms +step:613/1750 train_time:56952ms step_avg:92.91ms +step:614/1750 train_time:57047ms step_avg:92.91ms +step:615/1750 train_time:57141ms step_avg:92.91ms +step:616/1750 train_time:57236ms step_avg:92.92ms +step:617/1750 train_time:57330ms step_avg:92.92ms +step:618/1750 train_time:57425ms step_avg:92.92ms +step:619/1750 train_time:57519ms step_avg:92.92ms +step:620/1750 train_time:57613ms step_avg:92.92ms +step:621/1750 train_time:57708ms step_avg:92.93ms +step:622/1750 train_time:57802ms step_avg:92.93ms +step:623/1750 train_time:57896ms step_avg:92.93ms +step:624/1750 train_time:57990ms step_avg:92.93ms +step:625/1750 train_time:58086ms step_avg:92.94ms +step:625/1750 val_loss:3.6537 train_time:58184ms step_avg:93.09ms +step:626/1750 train_time:58208ms step_avg:92.98ms +step:627/1750 train_time:58285ms step_avg:92.96ms +step:628/1750 train_time:58383ms step_avg:92.97ms +step:629/1750 train_time:58478ms step_avg:92.97ms +step:630/1750 train_time:58571ms step_avg:92.97ms +step:631/1750 train_time:58665ms step_avg:92.97ms +step:632/1750 train_time:58759ms step_avg:92.97ms +step:633/1750 train_time:58852ms step_avg:92.97ms +step:634/1750 train_time:58946ms step_avg:92.97ms +step:635/1750 train_time:59040ms step_avg:92.98ms +step:636/1750 train_time:59133ms step_avg:92.98ms +step:637/1750 
train_time:59229ms step_avg:92.98ms +step:638/1750 train_time:59325ms step_avg:92.99ms +step:639/1750 train_time:59420ms step_avg:92.99ms +step:640/1750 train_time:59516ms step_avg:92.99ms +step:641/1750 train_time:59611ms step_avg:93.00ms +step:642/1750 train_time:59706ms step_avg:93.00ms +step:643/1750 train_time:59800ms step_avg:93.00ms +step:644/1750 train_time:59894ms step_avg:93.00ms +step:645/1750 train_time:59987ms step_avg:93.00ms +step:646/1750 train_time:60081ms step_avg:93.00ms +step:647/1750 train_time:60176ms step_avg:93.01ms +step:648/1750 train_time:60271ms step_avg:93.01ms +step:649/1750 train_time:60367ms step_avg:93.01ms +step:650/1750 train_time:60462ms step_avg:93.02ms +step:651/1750 train_time:60559ms step_avg:93.02ms +step:652/1750 train_time:60655ms step_avg:93.03ms +step:653/1750 train_time:60750ms step_avg:93.03ms +step:654/1750 train_time:60846ms step_avg:93.04ms +step:655/1750 train_time:60941ms step_avg:93.04ms +step:656/1750 train_time:61037ms step_avg:93.04ms +step:657/1750 train_time:61132ms step_avg:93.05ms +step:658/1750 train_time:61228ms step_avg:93.05ms +step:659/1750 train_time:61324ms step_avg:93.06ms +step:660/1750 train_time:61421ms step_avg:93.06ms +step:661/1750 train_time:61517ms step_avg:93.07ms +step:662/1750 train_time:61614ms step_avg:93.07ms +step:663/1750 train_time:61709ms step_avg:93.08ms +step:664/1750 train_time:61805ms step_avg:93.08ms +step:665/1750 train_time:61900ms step_avg:93.08ms +step:666/1750 train_time:61996ms step_avg:93.09ms +step:667/1750 train_time:62092ms step_avg:93.09ms +step:668/1750 train_time:62188ms step_avg:93.10ms +step:669/1750 train_time:62284ms step_avg:93.10ms +step:670/1750 train_time:62381ms step_avg:93.11ms +step:671/1750 train_time:62477ms step_avg:93.11ms +step:672/1750 train_time:62573ms step_avg:93.11ms +step:673/1750 train_time:62669ms step_avg:93.12ms +step:674/1750 train_time:62765ms step_avg:93.12ms +step:675/1750 train_time:62860ms step_avg:93.13ms +step:676/1750 
train_time:62956ms step_avg:93.13ms +step:677/1750 train_time:63051ms step_avg:93.13ms +step:678/1750 train_time:63147ms step_avg:93.14ms +step:679/1750 train_time:63243ms step_avg:93.14ms +step:680/1750 train_time:63340ms step_avg:93.15ms +step:681/1750 train_time:63436ms step_avg:93.15ms +step:682/1750 train_time:63532ms step_avg:93.16ms +step:683/1750 train_time:63627ms step_avg:93.16ms +step:684/1750 train_time:63723ms step_avg:93.16ms +step:685/1750 train_time:63820ms step_avg:93.17ms +step:686/1750 train_time:63915ms step_avg:93.17ms +step:687/1750 train_time:64011ms step_avg:93.17ms +step:688/1750 train_time:64106ms step_avg:93.18ms +step:689/1750 train_time:64202ms step_avg:93.18ms +step:690/1750 train_time:64299ms step_avg:93.19ms +step:691/1750 train_time:64394ms step_avg:93.19ms +step:692/1750 train_time:64490ms step_avg:93.19ms +step:693/1750 train_time:64586ms step_avg:93.20ms +step:694/1750 train_time:64681ms step_avg:93.20ms +step:695/1750 train_time:64778ms step_avg:93.21ms +step:696/1750 train_time:64874ms step_avg:93.21ms +step:697/1750 train_time:64969ms step_avg:93.21ms +step:698/1750 train_time:65065ms step_avg:93.22ms +step:699/1750 train_time:65161ms step_avg:93.22ms +step:700/1750 train_time:65257ms step_avg:93.22ms +step:701/1750 train_time:65353ms step_avg:93.23ms +step:702/1750 train_time:65450ms step_avg:93.23ms +step:703/1750 train_time:65545ms step_avg:93.24ms +step:704/1750 train_time:65642ms step_avg:93.24ms +step:705/1750 train_time:65738ms step_avg:93.25ms +step:706/1750 train_time:65834ms step_avg:93.25ms +step:707/1750 train_time:65930ms step_avg:93.25ms +step:708/1750 train_time:66025ms step_avg:93.26ms +step:709/1750 train_time:66122ms step_avg:93.26ms +step:710/1750 train_time:66218ms step_avg:93.26ms +step:711/1750 train_time:66314ms step_avg:93.27ms +step:712/1750 train_time:66410ms step_avg:93.27ms +step:713/1750 train_time:66505ms step_avg:93.28ms +step:714/1750 train_time:66601ms step_avg:93.28ms +step:715/1750 
train_time:66698ms step_avg:93.28ms +step:716/1750 train_time:66793ms step_avg:93.29ms +step:717/1750 train_time:66889ms step_avg:93.29ms +step:718/1750 train_time:66985ms step_avg:93.29ms +step:719/1750 train_time:67081ms step_avg:93.30ms +step:720/1750 train_time:67176ms step_avg:93.30ms +step:721/1750 train_time:67272ms step_avg:93.30ms +step:722/1750 train_time:67368ms step_avg:93.31ms +step:723/1750 train_time:67464ms step_avg:93.31ms +step:724/1750 train_time:67561ms step_avg:93.32ms +step:725/1750 train_time:67657ms step_avg:93.32ms +step:726/1750 train_time:67753ms step_avg:93.32ms +step:727/1750 train_time:67850ms step_avg:93.33ms +step:728/1750 train_time:67945ms step_avg:93.33ms +step:729/1750 train_time:68042ms step_avg:93.34ms +step:730/1750 train_time:68138ms step_avg:93.34ms +step:731/1750 train_time:68234ms step_avg:93.34ms +step:732/1750 train_time:68330ms step_avg:93.35ms +step:733/1750 train_time:68425ms step_avg:93.35ms +step:734/1750 train_time:68521ms step_avg:93.35ms +step:735/1750 train_time:68617ms step_avg:93.36ms +step:736/1750 train_time:68713ms step_avg:93.36ms +step:737/1750 train_time:68809ms step_avg:93.36ms +step:738/1750 train_time:68904ms step_avg:93.37ms +step:739/1750 train_time:69000ms step_avg:93.37ms +step:740/1750 train_time:69097ms step_avg:93.37ms +step:741/1750 train_time:69193ms step_avg:93.38ms +step:742/1750 train_time:69289ms step_avg:93.38ms +step:743/1750 train_time:69385ms step_avg:93.38ms +step:744/1750 train_time:69481ms step_avg:93.39ms +step:745/1750 train_time:69577ms step_avg:93.39ms +step:746/1750 train_time:69673ms step_avg:93.40ms +step:747/1750 train_time:69770ms step_avg:93.40ms +step:748/1750 train_time:69866ms step_avg:93.40ms +step:749/1750 train_time:69961ms step_avg:93.41ms +step:750/1750 train_time:70057ms step_avg:93.41ms +step:750/1750 val_loss:3.5904 train_time:70156ms step_avg:93.54ms +step:751/1750 train_time:70181ms step_avg:93.45ms +step:752/1750 train_time:70258ms step_avg:93.43ms 
+step:753/1750 train_time:70358ms step_avg:93.44ms +step:754/1750 train_time:70454ms step_avg:93.44ms +step:755/1750 train_time:70550ms step_avg:93.44ms +step:756/1750 train_time:70645ms step_avg:93.45ms +step:757/1750 train_time:70740ms step_avg:93.45ms +step:758/1750 train_time:70835ms step_avg:93.45ms +step:759/1750 train_time:70930ms step_avg:93.45ms +step:760/1750 train_time:71024ms step_avg:93.45ms +step:761/1750 train_time:71120ms step_avg:93.46ms +step:762/1750 train_time:71218ms step_avg:93.46ms +step:763/1750 train_time:71316ms step_avg:93.47ms +step:764/1750 train_time:71413ms step_avg:93.47ms +step:765/1750 train_time:71510ms step_avg:93.48ms +step:766/1750 train_time:71606ms step_avg:93.48ms +step:767/1750 train_time:71701ms step_avg:93.48ms +step:768/1750 train_time:71797ms step_avg:93.49ms +step:769/1750 train_time:71891ms step_avg:93.49ms +step:770/1750 train_time:71986ms step_avg:93.49ms +step:771/1750 train_time:72082ms step_avg:93.49ms +step:772/1750 train_time:72178ms step_avg:93.49ms +step:773/1750 train_time:72275ms step_avg:93.50ms +step:774/1750 train_time:72372ms step_avg:93.50ms +step:775/1750 train_time:72469ms step_avg:93.51ms +step:776/1750 train_time:72566ms step_avg:93.51ms +step:777/1750 train_time:72661ms step_avg:93.52ms +step:778/1750 train_time:72756ms step_avg:93.52ms +step:779/1750 train_time:72851ms step_avg:93.52ms +step:780/1750 train_time:72947ms step_avg:93.52ms +step:781/1750 train_time:73042ms step_avg:93.52ms +step:782/1750 train_time:73138ms step_avg:93.53ms +step:783/1750 train_time:73234ms step_avg:93.53ms +step:784/1750 train_time:73330ms step_avg:93.53ms +step:785/1750 train_time:73427ms step_avg:93.54ms +step:786/1750 train_time:73524ms step_avg:93.54ms +step:787/1750 train_time:73620ms step_avg:93.54ms +step:788/1750 train_time:73715ms step_avg:93.55ms +step:789/1750 train_time:73811ms step_avg:93.55ms +step:790/1750 train_time:73907ms step_avg:93.55ms +step:791/1750 train_time:74003ms step_avg:93.56ms 
+step:792/1750 train_time:74098ms step_avg:93.56ms +step:793/1750 train_time:74195ms step_avg:93.56ms +step:794/1750 train_time:74291ms step_avg:93.57ms +step:795/1750 train_time:74388ms step_avg:93.57ms +step:796/1750 train_time:74485ms step_avg:93.57ms +step:797/1750 train_time:74581ms step_avg:93.58ms +step:798/1750 train_time:74677ms step_avg:93.58ms +step:799/1750 train_time:74773ms step_avg:93.58ms +step:800/1750 train_time:74869ms step_avg:93.59ms +step:801/1750 train_time:74965ms step_avg:93.59ms +step:802/1750 train_time:75061ms step_avg:93.59ms +step:803/1750 train_time:75157ms step_avg:93.59ms +step:804/1750 train_time:75253ms step_avg:93.60ms +step:805/1750 train_time:75349ms step_avg:93.60ms +step:806/1750 train_time:75445ms step_avg:93.60ms +step:807/1750 train_time:75541ms step_avg:93.61ms +step:808/1750 train_time:75638ms step_avg:93.61ms +step:809/1750 train_time:75734ms step_avg:93.61ms +step:810/1750 train_time:75830ms step_avg:93.62ms +step:811/1750 train_time:75926ms step_avg:93.62ms +step:812/1750 train_time:76022ms step_avg:93.62ms +step:813/1750 train_time:76118ms step_avg:93.63ms +step:814/1750 train_time:76214ms step_avg:93.63ms +step:815/1750 train_time:76311ms step_avg:93.63ms +step:816/1750 train_time:76408ms step_avg:93.64ms +step:817/1750 train_time:76505ms step_avg:93.64ms +step:818/1750 train_time:76601ms step_avg:93.64ms +step:819/1750 train_time:76697ms step_avg:93.65ms +step:820/1750 train_time:76793ms step_avg:93.65ms +step:821/1750 train_time:76889ms step_avg:93.65ms +step:822/1750 train_time:76986ms step_avg:93.66ms +step:823/1750 train_time:77082ms step_avg:93.66ms +step:824/1750 train_time:77178ms step_avg:93.66ms +step:825/1750 train_time:77274ms step_avg:93.67ms +step:826/1750 train_time:77370ms step_avg:93.67ms +step:827/1750 train_time:77467ms step_avg:93.67ms +step:828/1750 train_time:77563ms step_avg:93.68ms +step:829/1750 train_time:77659ms step_avg:93.68ms +step:830/1750 train_time:77755ms step_avg:93.68ms 
+step:831/1750 train_time:77852ms step_avg:93.68ms +step:832/1750 train_time:77949ms step_avg:93.69ms +step:833/1750 train_time:78045ms step_avg:93.69ms +step:834/1750 train_time:78141ms step_avg:93.69ms +step:835/1750 train_time:78237ms step_avg:93.70ms +step:836/1750 train_time:78334ms step_avg:93.70ms +step:837/1750 train_time:78431ms step_avg:93.70ms +step:838/1750 train_time:78527ms step_avg:93.71ms +step:839/1750 train_time:78623ms step_avg:93.71ms +step:840/1750 train_time:78720ms step_avg:93.71ms +step:841/1750 train_time:78817ms step_avg:93.72ms +step:842/1750 train_time:78914ms step_avg:93.72ms +step:843/1750 train_time:79010ms step_avg:93.72ms +step:844/1750 train_time:79106ms step_avg:93.73ms +step:845/1750 train_time:79203ms step_avg:93.73ms +step:846/1750 train_time:79299ms step_avg:93.73ms +step:847/1750 train_time:79395ms step_avg:93.74ms +step:848/1750 train_time:79492ms step_avg:93.74ms +step:849/1750 train_time:79588ms step_avg:93.74ms +step:850/1750 train_time:79685ms step_avg:93.75ms +step:851/1750 train_time:79781ms step_avg:93.75ms +step:852/1750 train_time:79877ms step_avg:93.75ms +step:853/1750 train_time:79973ms step_avg:93.76ms +step:854/1750 train_time:80070ms step_avg:93.76ms +step:855/1750 train_time:80166ms step_avg:93.76ms +step:856/1750 train_time:80262ms step_avg:93.76ms +step:857/1750 train_time:80358ms step_avg:93.77ms +step:858/1750 train_time:80454ms step_avg:93.77ms +step:859/1750 train_time:80550ms step_avg:93.77ms +step:860/1750 train_time:80646ms step_avg:93.77ms +step:861/1750 train_time:80743ms step_avg:93.78ms +step:862/1750 train_time:80839ms step_avg:93.78ms +step:863/1750 train_time:80935ms step_avg:93.78ms +step:864/1750 train_time:81031ms step_avg:93.79ms +step:865/1750 train_time:81127ms step_avg:93.79ms +step:866/1750 train_time:81224ms step_avg:93.79ms +step:867/1750 train_time:81320ms step_avg:93.79ms +step:868/1750 train_time:81416ms step_avg:93.80ms +step:869/1750 train_time:81512ms step_avg:93.80ms 
+step:870/1750 train_time:81608ms step_avg:93.80ms +step:871/1750 train_time:81705ms step_avg:93.81ms +step:872/1750 train_time:81801ms step_avg:93.81ms +step:873/1750 train_time:81898ms step_avg:93.81ms +step:874/1750 train_time:81995ms step_avg:93.82ms +step:875/1750 train_time:82091ms step_avg:93.82ms +step:875/1750 val_loss:3.5431 train_time:82190ms step_avg:93.93ms +step:876/1750 train_time:82213ms step_avg:93.85ms +step:877/1750 train_time:82295ms step_avg:93.84ms +step:878/1750 train_time:82396ms step_avg:93.84ms +step:879/1750 train_time:82492ms step_avg:93.85ms +step:880/1750 train_time:82588ms step_avg:93.85ms +step:881/1750 train_time:82683ms step_avg:93.85ms +step:882/1750 train_time:82779ms step_avg:93.85ms +step:883/1750 train_time:82874ms step_avg:93.85ms +step:884/1750 train_time:82969ms step_avg:93.86ms +step:885/1750 train_time:83064ms step_avg:93.86ms +step:886/1750 train_time:83160ms step_avg:93.86ms +step:887/1750 train_time:83259ms step_avg:93.87ms +step:888/1750 train_time:83358ms step_avg:93.87ms +step:889/1750 train_time:83456ms step_avg:93.88ms +step:890/1750 train_time:83552ms step_avg:93.88ms +step:891/1750 train_time:83648ms step_avg:93.88ms +step:892/1750 train_time:83744ms step_avg:93.88ms +step:893/1750 train_time:83839ms step_avg:93.88ms +step:894/1750 train_time:83933ms step_avg:93.89ms +step:895/1750 train_time:84029ms step_avg:93.89ms +step:896/1750 train_time:84124ms step_avg:93.89ms +step:897/1750 train_time:84221ms step_avg:93.89ms +step:898/1750 train_time:84319ms step_avg:93.90ms +step:899/1750 train_time:84416ms step_avg:93.90ms +step:900/1750 train_time:84513ms step_avg:93.90ms +step:901/1750 train_time:84610ms step_avg:93.91ms +step:902/1750 train_time:84706ms step_avg:93.91ms +step:903/1750 train_time:84801ms step_avg:93.91ms +step:904/1750 train_time:84897ms step_avg:93.91ms +step:905/1750 train_time:84992ms step_avg:93.91ms +step:906/1750 train_time:85088ms step_avg:93.92ms +step:907/1750 train_time:85184ms 
step_avg:93.92ms +step:908/1750 train_time:85281ms step_avg:93.92ms +step:909/1750 train_time:85378ms step_avg:93.92ms +step:910/1750 train_time:85476ms step_avg:93.93ms +step:911/1750 train_time:85574ms step_avg:93.93ms +step:912/1750 train_time:85671ms step_avg:93.94ms +step:913/1750 train_time:85769ms step_avg:93.94ms +step:914/1750 train_time:85867ms step_avg:93.95ms +step:915/1750 train_time:85964ms step_avg:93.95ms +step:916/1750 train_time:86061ms step_avg:93.95ms +step:917/1750 train_time:86158ms step_avg:93.96ms +step:918/1750 train_time:86255ms step_avg:93.96ms +step:919/1750 train_time:86352ms step_avg:93.96ms +step:920/1750 train_time:86450ms step_avg:93.97ms +step:921/1750 train_time:86548ms step_avg:93.97ms +step:922/1750 train_time:86646ms step_avg:93.98ms +step:923/1750 train_time:86744ms step_avg:93.98ms +step:924/1750 train_time:86841ms step_avg:93.98ms +step:925/1750 train_time:86937ms step_avg:93.99ms +step:926/1750 train_time:87035ms step_avg:93.99ms +step:927/1750 train_time:87131ms step_avg:93.99ms +step:928/1750 train_time:87229ms step_avg:94.00ms +step:929/1750 train_time:87327ms step_avg:94.00ms +step:930/1750 train_time:87425ms step_avg:94.01ms +step:931/1750 train_time:87523ms step_avg:94.01ms +step:932/1750 train_time:87621ms step_avg:94.01ms +step:933/1750 train_time:87719ms step_avg:94.02ms +step:934/1750 train_time:87816ms step_avg:94.02ms +step:935/1750 train_time:87913ms step_avg:94.02ms +step:936/1750 train_time:88010ms step_avg:94.03ms +step:937/1750 train_time:88108ms step_avg:94.03ms +step:938/1750 train_time:88206ms step_avg:94.04ms +step:939/1750 train_time:88304ms step_avg:94.04ms +step:940/1750 train_time:88402ms step_avg:94.04ms +step:941/1750 train_time:88500ms step_avg:94.05ms +step:942/1750 train_time:88597ms step_avg:94.05ms +step:943/1750 train_time:88695ms step_avg:94.06ms +step:944/1750 train_time:88792ms step_avg:94.06ms +step:945/1750 train_time:88889ms step_avg:94.06ms +step:946/1750 train_time:88986ms 
step_avg:94.07ms +step:947/1750 train_time:89083ms step_avg:94.07ms +step:948/1750 train_time:89181ms step_avg:94.07ms +step:949/1750 train_time:89278ms step_avg:94.08ms +step:950/1750 train_time:89377ms step_avg:94.08ms +step:951/1750 train_time:89473ms step_avg:94.08ms +step:952/1750 train_time:89571ms step_avg:94.09ms +step:953/1750 train_time:89668ms step_avg:94.09ms +step:954/1750 train_time:89765ms step_avg:94.09ms +step:955/1750 train_time:89863ms step_avg:94.10ms +step:956/1750 train_time:89962ms step_avg:94.10ms +step:957/1750 train_time:90058ms step_avg:94.10ms +step:958/1750 train_time:90155ms step_avg:94.11ms +step:959/1750 train_time:90253ms step_avg:94.11ms +step:960/1750 train_time:90351ms step_avg:94.12ms +step:961/1750 train_time:90449ms step_avg:94.12ms +step:962/1750 train_time:90547ms step_avg:94.12ms +step:963/1750 train_time:90645ms step_avg:94.13ms +step:964/1750 train_time:90743ms step_avg:94.13ms +step:965/1750 train_time:90840ms step_avg:94.14ms +step:966/1750 train_time:90937ms step_avg:94.14ms +step:967/1750 train_time:91035ms step_avg:94.14ms +step:968/1750 train_time:91133ms step_avg:94.15ms +step:969/1750 train_time:91230ms step_avg:94.15ms +step:970/1750 train_time:91328ms step_avg:94.15ms +step:971/1750 train_time:91427ms step_avg:94.16ms +step:972/1750 train_time:91524ms step_avg:94.16ms +step:973/1750 train_time:91622ms step_avg:94.16ms +step:974/1750 train_time:91719ms step_avg:94.17ms +step:975/1750 train_time:91818ms step_avg:94.17ms +step:976/1750 train_time:91915ms step_avg:94.18ms +step:977/1750 train_time:92013ms step_avg:94.18ms +step:978/1750 train_time:92110ms step_avg:94.18ms +step:979/1750 train_time:92209ms step_avg:94.19ms +step:980/1750 train_time:92306ms step_avg:94.19ms +step:981/1750 train_time:92403ms step_avg:94.19ms +step:982/1750 train_time:92500ms step_avg:94.20ms +step:983/1750 train_time:92597ms step_avg:94.20ms +step:984/1750 train_time:92695ms step_avg:94.20ms +step:985/1750 train_time:92794ms 
step_avg:94.21ms +step:986/1750 train_time:92891ms step_avg:94.21ms +step:987/1750 train_time:92989ms step_avg:94.21ms +step:988/1750 train_time:93087ms step_avg:94.22ms +step:989/1750 train_time:93185ms step_avg:94.22ms +step:990/1750 train_time:93283ms step_avg:94.23ms +step:991/1750 train_time:93381ms step_avg:94.23ms +step:992/1750 train_time:93479ms step_avg:94.23ms +step:993/1750 train_time:93575ms step_avg:94.24ms +step:994/1750 train_time:93673ms step_avg:94.24ms +step:995/1750 train_time:93770ms step_avg:94.24ms +step:996/1750 train_time:93868ms step_avg:94.25ms +step:997/1750 train_time:93967ms step_avg:94.25ms +step:998/1750 train_time:94063ms step_avg:94.25ms +step:999/1750 train_time:94160ms step_avg:94.25ms +step:1000/1750 train_time:94258ms step_avg:94.26ms +step:1000/1750 val_loss:3.5036 train_time:94359ms step_avg:94.36ms +step:1001/1750 train_time:94383ms step_avg:94.29ms +step:1002/1750 train_time:94465ms step_avg:94.28ms +step:1003/1750 train_time:94566ms step_avg:94.28ms +step:1004/1750 train_time:94664ms step_avg:94.29ms +step:1005/1750 train_time:94761ms step_avg:94.29ms +step:1006/1750 train_time:94857ms step_avg:94.29ms +step:1007/1750 train_time:94953ms step_avg:94.29ms +step:1008/1750 train_time:95049ms step_avg:94.29ms +step:1009/1750 train_time:95145ms step_avg:94.30ms +step:1010/1750 train_time:95242ms step_avg:94.30ms +step:1011/1750 train_time:95340ms step_avg:94.30ms +step:1012/1750 train_time:95441ms step_avg:94.31ms +step:1013/1750 train_time:95540ms step_avg:94.31ms +step:1014/1750 train_time:95639ms step_avg:94.32ms +step:1015/1750 train_time:95736ms step_avg:94.32ms +step:1016/1750 train_time:95834ms step_avg:94.32ms +step:1017/1750 train_time:95931ms step_avg:94.33ms +step:1018/1750 train_time:96027ms step_avg:94.33ms +step:1019/1750 train_time:96124ms step_avg:94.33ms +step:1020/1750 train_time:96220ms step_avg:94.33ms +step:1021/1750 train_time:96318ms step_avg:94.34ms +step:1022/1750 train_time:96417ms step_avg:94.34ms 
+step:1023/1750 train_time:96516ms step_avg:94.35ms +step:1024/1750 train_time:96615ms step_avg:94.35ms +step:1025/1750 train_time:96712ms step_avg:94.35ms +step:1026/1750 train_time:96810ms step_avg:94.36ms +step:1027/1750 train_time:96906ms step_avg:94.36ms +step:1028/1750 train_time:97004ms step_avg:94.36ms +step:1029/1750 train_time:97102ms step_avg:94.37ms +step:1030/1750 train_time:97199ms step_avg:94.37ms +step:1031/1750 train_time:97296ms step_avg:94.37ms +step:1032/1750 train_time:97394ms step_avg:94.37ms +step:1033/1750 train_time:97492ms step_avg:94.38ms +step:1034/1750 train_time:97589ms step_avg:94.38ms +step:1035/1750 train_time:97688ms step_avg:94.38ms +step:1036/1750 train_time:97786ms step_avg:94.39ms +step:1037/1750 train_time:97884ms step_avg:94.39ms +step:1038/1750 train_time:97981ms step_avg:94.39ms +step:1039/1750 train_time:98078ms step_avg:94.40ms +step:1040/1750 train_time:98175ms step_avg:94.40ms +step:1041/1750 train_time:98273ms step_avg:94.40ms +step:1042/1750 train_time:98370ms step_avg:94.40ms +step:1043/1750 train_time:98468ms step_avg:94.41ms +step:1044/1750 train_time:98566ms step_avg:94.41ms +step:1045/1750 train_time:98664ms step_avg:94.42ms +step:1046/1750 train_time:98763ms step_avg:94.42ms +step:1047/1750 train_time:98861ms step_avg:94.42ms +step:1048/1750 train_time:98958ms step_avg:94.43ms +step:1049/1750 train_time:99055ms step_avg:94.43ms +step:1050/1750 train_time:99152ms step_avg:94.43ms +step:1051/1750 train_time:99250ms step_avg:94.43ms +step:1052/1750 train_time:99347ms step_avg:94.44ms +step:1053/1750 train_time:99444ms step_avg:94.44ms +step:1054/1750 train_time:99543ms step_avg:94.44ms +step:1055/1750 train_time:99641ms step_avg:94.45ms +step:1056/1750 train_time:99739ms step_avg:94.45ms +step:1057/1750 train_time:99837ms step_avg:94.45ms +step:1058/1750 train_time:99935ms step_avg:94.46ms +step:1059/1750 train_time:100032ms step_avg:94.46ms +step:1060/1750 train_time:100129ms step_avg:94.46ms +step:1061/1750 
train_time:100226ms step_avg:94.46ms +step:1062/1750 train_time:100324ms step_avg:94.47ms +step:1063/1750 train_time:100421ms step_avg:94.47ms +step:1064/1750 train_time:100520ms step_avg:94.47ms +step:1065/1750 train_time:100619ms step_avg:94.48ms +step:1066/1750 train_time:100717ms step_avg:94.48ms +step:1067/1750 train_time:100815ms step_avg:94.48ms +step:1068/1750 train_time:100912ms step_avg:94.49ms +step:1069/1750 train_time:101009ms step_avg:94.49ms +step:1070/1750 train_time:101107ms step_avg:94.49ms +step:1071/1750 train_time:101205ms step_avg:94.50ms +step:1072/1750 train_time:101303ms step_avg:94.50ms +step:1073/1750 train_time:101399ms step_avg:94.50ms +step:1074/1750 train_time:101497ms step_avg:94.50ms +step:1075/1750 train_time:101595ms step_avg:94.51ms +step:1076/1750 train_time:101693ms step_avg:94.51ms +step:1077/1750 train_time:101792ms step_avg:94.51ms +step:1078/1750 train_time:101889ms step_avg:94.52ms +step:1079/1750 train_time:101988ms step_avg:94.52ms +step:1080/1750 train_time:102085ms step_avg:94.52ms +step:1081/1750 train_time:102182ms step_avg:94.53ms +step:1082/1750 train_time:102280ms step_avg:94.53ms +step:1083/1750 train_time:102378ms step_avg:94.53ms +step:1084/1750 train_time:102474ms step_avg:94.53ms +step:1085/1750 train_time:102572ms step_avg:94.54ms +step:1086/1750 train_time:102670ms step_avg:94.54ms +step:1087/1750 train_time:102767ms step_avg:94.54ms +step:1088/1750 train_time:102865ms step_avg:94.55ms +step:1089/1750 train_time:102963ms step_avg:94.55ms +step:1090/1750 train_time:103062ms step_avg:94.55ms +step:1091/1750 train_time:103160ms step_avg:94.56ms +step:1092/1750 train_time:103258ms step_avg:94.56ms +step:1093/1750 train_time:103356ms step_avg:94.56ms +step:1094/1750 train_time:103453ms step_avg:94.56ms +step:1095/1750 train_time:103550ms step_avg:94.57ms +step:1096/1750 train_time:103648ms step_avg:94.57ms +step:1097/1750 train_time:103746ms step_avg:94.57ms +step:1098/1750 train_time:103844ms step_avg:94.58ms 
+step:1099/1750 train_time:103942ms step_avg:94.58ms +step:1100/1750 train_time:104039ms step_avg:94.58ms +step:1101/1750 train_time:104137ms step_avg:94.58ms +step:1102/1750 train_time:104234ms step_avg:94.59ms +step:1103/1750 train_time:104332ms step_avg:94.59ms +step:1104/1750 train_time:104429ms step_avg:94.59ms +step:1105/1750 train_time:104526ms step_avg:94.59ms +step:1106/1750 train_time:104625ms step_avg:94.60ms +step:1107/1750 train_time:104723ms step_avg:94.60ms +step:1108/1750 train_time:104822ms step_avg:94.60ms +step:1109/1750 train_time:104920ms step_avg:94.61ms +step:1110/1750 train_time:105018ms step_avg:94.61ms +step:1111/1750 train_time:105115ms step_avg:94.61ms +step:1112/1750 train_time:105213ms step_avg:94.62ms +step:1113/1750 train_time:105310ms step_avg:94.62ms +step:1114/1750 train_time:105408ms step_avg:94.62ms +step:1115/1750 train_time:105505ms step_avg:94.62ms +step:1116/1750 train_time:105602ms step_avg:94.63ms +step:1117/1750 train_time:105700ms step_avg:94.63ms +step:1118/1750 train_time:105798ms step_avg:94.63ms +step:1119/1750 train_time:105895ms step_avg:94.63ms +step:1120/1750 train_time:105994ms step_avg:94.64ms +step:1121/1750 train_time:106092ms step_avg:94.64ms +step:1122/1750 train_time:106190ms step_avg:94.64ms +step:1123/1750 train_time:106287ms step_avg:94.65ms +step:1124/1750 train_time:106385ms step_avg:94.65ms +step:1125/1750 train_time:106484ms step_avg:94.65ms +step:1125/1750 val_loss:3.4519 train_time:106584ms step_avg:94.74ms +step:1126/1750 train_time:106609ms step_avg:94.68ms +step:1127/1750 train_time:106687ms step_avg:94.66ms +step:1128/1750 train_time:106787ms step_avg:94.67ms +step:1129/1750 train_time:106885ms step_avg:94.67ms +step:1130/1750 train_time:106983ms step_avg:94.67ms +step:1131/1750 train_time:107079ms step_avg:94.68ms +step:1132/1750 train_time:107176ms step_avg:94.68ms +step:1133/1750 train_time:107273ms step_avg:94.68ms +step:1134/1750 train_time:107371ms step_avg:94.68ms +step:1135/1750 
train_time:107468ms step_avg:94.69ms +step:1136/1750 train_time:107567ms step_avg:94.69ms +step:1137/1750 train_time:107667ms step_avg:94.69ms +step:1138/1750 train_time:107766ms step_avg:94.70ms +step:1139/1750 train_time:107864ms step_avg:94.70ms +step:1140/1750 train_time:107961ms step_avg:94.70ms +step:1141/1750 train_time:108058ms step_avg:94.70ms +step:1142/1750 train_time:108154ms step_avg:94.71ms +step:1143/1750 train_time:108250ms step_avg:94.71ms +step:1144/1750 train_time:108347ms step_avg:94.71ms +step:1145/1750 train_time:108444ms step_avg:94.71ms +step:1146/1750 train_time:108543ms step_avg:94.71ms +step:1147/1750 train_time:108641ms step_avg:94.72ms +step:1148/1750 train_time:108740ms step_avg:94.72ms +step:1149/1750 train_time:108839ms step_avg:94.72ms +step:1150/1750 train_time:108936ms step_avg:94.73ms +step:1151/1750 train_time:109033ms step_avg:94.73ms +step:1152/1750 train_time:109130ms step_avg:94.73ms +step:1153/1750 train_time:109228ms step_avg:94.73ms +step:1154/1750 train_time:109324ms step_avg:94.73ms +step:1155/1750 train_time:109421ms step_avg:94.74ms +step:1156/1750 train_time:109519ms step_avg:94.74ms +step:1157/1750 train_time:109617ms step_avg:94.74ms +step:1158/1750 train_time:109715ms step_avg:94.75ms +step:1159/1750 train_time:109813ms step_avg:94.75ms +step:1160/1750 train_time:109911ms step_avg:94.75ms +step:1161/1750 train_time:110010ms step_avg:94.75ms +step:1162/1750 train_time:110108ms step_avg:94.76ms +step:1163/1750 train_time:110207ms step_avg:94.76ms +step:1164/1750 train_time:110304ms step_avg:94.76ms +step:1165/1750 train_time:110401ms step_avg:94.77ms +step:1166/1750 train_time:110499ms step_avg:94.77ms +step:1167/1750 train_time:110596ms step_avg:94.77ms +step:1168/1750 train_time:110695ms step_avg:94.77ms +step:1169/1750 train_time:110794ms step_avg:94.78ms +step:1170/1750 train_time:110892ms step_avg:94.78ms +step:1171/1750 train_time:110992ms step_avg:94.78ms +step:1172/1750 train_time:111090ms step_avg:94.79ms 
+step:1173/1750 train_time:111191ms step_avg:94.79ms +step:1174/1750 train_time:111291ms step_avg:94.80ms +step:1175/1750 train_time:111390ms step_avg:94.80ms +step:1176/1750 train_time:111489ms step_avg:94.80ms +step:1177/1750 train_time:111590ms step_avg:94.81ms +step:1178/1750 train_time:111689ms step_avg:94.81ms +step:1179/1750 train_time:111792ms step_avg:94.82ms +step:1180/1750 train_time:111891ms step_avg:94.82ms +step:1181/1750 train_time:111989ms step_avg:94.83ms +step:1182/1750 train_time:112088ms step_avg:94.83ms +step:1183/1750 train_time:112187ms step_avg:94.83ms +step:1184/1750 train_time:112286ms step_avg:94.84ms +step:1185/1750 train_time:112385ms step_avg:94.84ms +step:1186/1750 train_time:112485ms step_avg:94.84ms +step:1187/1750 train_time:112584ms step_avg:94.85ms +step:1188/1750 train_time:112683ms step_avg:94.85ms +step:1189/1750 train_time:112781ms step_avg:94.85ms +step:1190/1750 train_time:112881ms step_avg:94.86ms +step:1191/1750 train_time:112980ms step_avg:94.86ms +step:1192/1750 train_time:113080ms step_avg:94.87ms +step:1193/1750 train_time:113179ms step_avg:94.87ms +step:1194/1750 train_time:113278ms step_avg:94.87ms +step:1195/1750 train_time:113378ms step_avg:94.88ms +step:1196/1750 train_time:113476ms step_avg:94.88ms +step:1197/1750 train_time:113575ms step_avg:94.88ms +step:1198/1750 train_time:113675ms step_avg:94.89ms +step:1199/1750 train_time:113774ms step_avg:94.89ms +step:1200/1750 train_time:113873ms step_avg:94.89ms +step:1201/1750 train_time:113973ms step_avg:94.90ms +step:1202/1750 train_time:114073ms step_avg:94.90ms +step:1203/1750 train_time:114172ms step_avg:94.91ms +step:1204/1750 train_time:114271ms step_avg:94.91ms +step:1205/1750 train_time:114370ms step_avg:94.91ms +step:1206/1750 train_time:114469ms step_avg:94.92ms +step:1207/1750 train_time:114568ms step_avg:94.92ms +step:1208/1750 train_time:114668ms step_avg:94.92ms +step:1209/1750 train_time:114767ms step_avg:94.93ms +step:1210/1750 train_time:114866ms 
step_avg:94.93ms +step:1211/1750 train_time:114966ms step_avg:94.93ms +step:1212/1750 train_time:115064ms step_avg:94.94ms +step:1213/1750 train_time:115163ms step_avg:94.94ms +step:1214/1750 train_time:115261ms step_avg:94.94ms +step:1215/1750 train_time:115361ms step_avg:94.95ms +step:1216/1750 train_time:115460ms step_avg:94.95ms +step:1217/1750 train_time:115560ms step_avg:94.95ms +step:1218/1750 train_time:115660ms step_avg:94.96ms +step:1219/1750 train_time:115759ms step_avg:94.96ms +step:1220/1750 train_time:115859ms step_avg:94.97ms +step:1221/1750 train_time:115958ms step_avg:94.97ms +step:1222/1750 train_time:116057ms step_avg:94.97ms +step:1223/1750 train_time:116157ms step_avg:94.98ms +step:1224/1750 train_time:116256ms step_avg:94.98ms +step:1225/1750 train_time:116355ms step_avg:94.98ms +step:1226/1750 train_time:116455ms step_avg:94.99ms +step:1227/1750 train_time:116553ms step_avg:94.99ms +step:1228/1750 train_time:116652ms step_avg:94.99ms +step:1229/1750 train_time:116752ms step_avg:95.00ms +step:1230/1750 train_time:116850ms step_avg:95.00ms +step:1231/1750 train_time:116950ms step_avg:95.00ms +step:1232/1750 train_time:117049ms step_avg:95.01ms +step:1233/1750 train_time:117148ms step_avg:95.01ms +step:1234/1750 train_time:117249ms step_avg:95.02ms +step:1235/1750 train_time:117348ms step_avg:95.02ms +step:1236/1750 train_time:117447ms step_avg:95.02ms +step:1237/1750 train_time:117548ms step_avg:95.03ms +step:1238/1750 train_time:117647ms step_avg:95.03ms +step:1239/1750 train_time:117746ms step_avg:95.03ms +step:1240/1750 train_time:117845ms step_avg:95.04ms +step:1241/1750 train_time:117944ms step_avg:95.04ms +step:1242/1750 train_time:118042ms step_avg:95.04ms +step:1243/1750 train_time:118141ms step_avg:95.04ms +step:1244/1750 train_time:118241ms step_avg:95.05ms +step:1245/1750 train_time:118341ms step_avg:95.05ms +step:1246/1750 train_time:118441ms step_avg:95.06ms +step:1247/1750 train_time:118539ms step_avg:95.06ms +step:1248/1750 
train_time:118638ms step_avg:95.06ms +step:1249/1750 train_time:118738ms step_avg:95.07ms +step:1250/1750 train_time:118836ms step_avg:95.07ms +step:1250/1750 val_loss:3.4067 train_time:118938ms step_avg:95.15ms +step:1251/1750 train_time:118962ms step_avg:95.09ms +step:1252/1750 train_time:119048ms step_avg:95.09ms +step:1253/1750 train_time:119149ms step_avg:95.09ms +step:1254/1750 train_time:119248ms step_avg:95.09ms +step:1255/1750 train_time:119347ms step_avg:95.10ms +step:1256/1750 train_time:119445ms step_avg:95.10ms +step:1257/1750 train_time:119542ms step_avg:95.10ms +step:1258/1750 train_time:119640ms step_avg:95.10ms +step:1259/1750 train_time:119737ms step_avg:95.10ms +step:1260/1750 train_time:119834ms step_avg:95.11ms +step:1261/1750 train_time:119932ms step_avg:95.11ms +step:1262/1750 train_time:120033ms step_avg:95.11ms +step:1263/1750 train_time:120133ms step_avg:95.12ms +step:1264/1750 train_time:120232ms step_avg:95.12ms +step:1265/1750 train_time:120332ms step_avg:95.12ms +step:1266/1750 train_time:120430ms step_avg:95.13ms +step:1267/1750 train_time:120529ms step_avg:95.13ms +step:1268/1750 train_time:120627ms step_avg:95.13ms +step:1269/1750 train_time:120724ms step_avg:95.13ms +step:1270/1750 train_time:120822ms step_avg:95.14ms +step:1271/1750 train_time:120924ms step_avg:95.14ms +step:1272/1750 train_time:121023ms step_avg:95.14ms +step:1273/1750 train_time:121123ms step_avg:95.15ms +step:1274/1750 train_time:121221ms step_avg:95.15ms +step:1275/1750 train_time:121320ms step_avg:95.15ms +step:1276/1750 train_time:121420ms step_avg:95.16ms +step:1277/1750 train_time:121519ms step_avg:95.16ms +step:1278/1750 train_time:121619ms step_avg:95.16ms +step:1279/1750 train_time:121718ms step_avg:95.17ms +step:1280/1750 train_time:121817ms step_avg:95.17ms +step:1281/1750 train_time:121917ms step_avg:95.17ms +step:1282/1750 train_time:122016ms step_avg:95.18ms +step:1283/1750 train_time:122115ms step_avg:95.18ms +step:1284/1750 train_time:122213ms 
step_avg:95.18ms +step:1285/1750 train_time:122311ms step_avg:95.18ms +step:1286/1750 train_time:122410ms step_avg:95.19ms +step:1287/1750 train_time:122509ms step_avg:95.19ms +step:1288/1750 train_time:122609ms step_avg:95.19ms +step:1289/1750 train_time:122709ms step_avg:95.20ms +step:1290/1750 train_time:122807ms step_avg:95.20ms +step:1291/1750 train_time:122906ms step_avg:95.20ms +step:1292/1750 train_time:123006ms step_avg:95.21ms +step:1293/1750 train_time:123105ms step_avg:95.21ms +step:1294/1750 train_time:123206ms step_avg:95.21ms +step:1295/1750 train_time:123305ms step_avg:95.22ms +step:1296/1750 train_time:123405ms step_avg:95.22ms +step:1297/1750 train_time:123504ms step_avg:95.22ms +step:1298/1750 train_time:123602ms step_avg:95.23ms +step:1299/1750 train_time:123701ms step_avg:95.23ms +step:1300/1750 train_time:123800ms step_avg:95.23ms +step:1301/1750 train_time:123899ms step_avg:95.23ms +step:1302/1750 train_time:123997ms step_avg:95.24ms +step:1303/1750 train_time:124097ms step_avg:95.24ms +step:1304/1750 train_time:124197ms step_avg:95.24ms +step:1305/1750 train_time:124296ms step_avg:95.25ms +step:1306/1750 train_time:124395ms step_avg:95.25ms +step:1307/1750 train_time:124496ms step_avg:95.25ms +step:1308/1750 train_time:124595ms step_avg:95.26ms +step:1309/1750 train_time:124694ms step_avg:95.26ms +step:1310/1750 train_time:124794ms step_avg:95.26ms +step:1311/1750 train_time:124894ms step_avg:95.27ms +step:1312/1750 train_time:124993ms step_avg:95.27ms +step:1313/1750 train_time:125093ms step_avg:95.27ms +step:1314/1750 train_time:125193ms step_avg:95.28ms +step:1315/1750 train_time:125292ms step_avg:95.28ms +step:1316/1750 train_time:125391ms step_avg:95.28ms +step:1317/1750 train_time:125490ms step_avg:95.28ms +step:1318/1750 train_time:125588ms step_avg:95.29ms +step:1319/1750 train_time:125687ms step_avg:95.29ms +step:1320/1750 train_time:125788ms step_avg:95.29ms +step:1321/1750 train_time:125887ms step_avg:95.30ms +step:1322/1750 
train_time:125986ms step_avg:95.30ms +step:1323/1750 train_time:126085ms step_avg:95.30ms +step:1324/1750 train_time:126186ms step_avg:95.31ms +step:1325/1750 train_time:126285ms step_avg:95.31ms +step:1326/1750 train_time:126384ms step_avg:95.31ms +step:1327/1750 train_time:126482ms step_avg:95.31ms +step:1328/1750 train_time:126581ms step_avg:95.32ms +step:1329/1750 train_time:126680ms step_avg:95.32ms +step:1330/1750 train_time:126780ms step_avg:95.32ms +step:1331/1750 train_time:126878ms step_avg:95.33ms +step:1332/1750 train_time:126977ms step_avg:95.33ms +step:1333/1750 train_time:127078ms step_avg:95.33ms +step:1334/1750 train_time:127177ms step_avg:95.34ms +step:1335/1750 train_time:127277ms step_avg:95.34ms +step:1336/1750 train_time:127376ms step_avg:95.34ms +step:1337/1750 train_time:127475ms step_avg:95.34ms +step:1338/1750 train_time:127574ms step_avg:95.35ms +step:1339/1750 train_time:127673ms step_avg:95.35ms +step:1340/1750 train_time:127773ms step_avg:95.35ms +step:1341/1750 train_time:127872ms step_avg:95.36ms +step:1342/1750 train_time:127972ms step_avg:95.36ms +step:1343/1750 train_time:128072ms step_avg:95.36ms +step:1344/1750 train_time:128171ms step_avg:95.37ms +step:1345/1750 train_time:128271ms step_avg:95.37ms +step:1346/1750 train_time:128370ms step_avg:95.37ms +step:1347/1750 train_time:128469ms step_avg:95.37ms +step:1348/1750 train_time:128569ms step_avg:95.38ms +step:1349/1750 train_time:128666ms step_avg:95.38ms +step:1350/1750 train_time:128767ms step_avg:95.38ms +step:1351/1750 train_time:128865ms step_avg:95.38ms +step:1352/1750 train_time:128964ms step_avg:95.39ms +step:1353/1750 train_time:129064ms step_avg:95.39ms +step:1354/1750 train_time:129163ms step_avg:95.39ms +step:1355/1750 train_time:129262ms step_avg:95.40ms +step:1356/1750 train_time:129362ms step_avg:95.40ms +step:1357/1750 train_time:129460ms step_avg:95.40ms +step:1358/1750 train_time:129560ms step_avg:95.40ms +step:1359/1750 train_time:129658ms step_avg:95.41ms 
+step:1360/1750 train_time:129757ms step_avg:95.41ms +step:1361/1750 train_time:129856ms step_avg:95.41ms +step:1362/1750 train_time:129957ms step_avg:95.42ms +step:1363/1750 train_time:130058ms step_avg:95.42ms +step:1364/1750 train_time:130157ms step_avg:95.42ms +step:1365/1750 train_time:130257ms step_avg:95.43ms +step:1366/1750 train_time:130355ms step_avg:95.43ms +step:1367/1750 train_time:130453ms step_avg:95.43ms +step:1368/1750 train_time:130552ms step_avg:95.43ms +step:1369/1750 train_time:130650ms step_avg:95.43ms +step:1370/1750 train_time:130749ms step_avg:95.44ms +step:1371/1750 train_time:130849ms step_avg:95.44ms +step:1372/1750 train_time:130948ms step_avg:95.44ms +step:1373/1750 train_time:131048ms step_avg:95.45ms +step:1374/1750 train_time:131147ms step_avg:95.45ms +step:1375/1750 train_time:131248ms step_avg:95.45ms +step:1375/1750 val_loss:3.3669 train_time:131351ms step_avg:95.53ms +step:1376/1750 train_time:131375ms step_avg:95.48ms +step:1377/1750 train_time:131458ms step_avg:95.47ms +step:1378/1750 train_time:131559ms step_avg:95.47ms +step:1379/1750 train_time:131659ms step_avg:95.47ms +step:1380/1750 train_time:131759ms step_avg:95.48ms +step:1381/1750 train_time:131857ms step_avg:95.48ms +step:1382/1750 train_time:131955ms step_avg:95.48ms +step:1383/1750 train_time:132053ms step_avg:95.48ms +step:1384/1750 train_time:132150ms step_avg:95.48ms +step:1385/1750 train_time:132248ms step_avg:95.49ms +step:1386/1750 train_time:132348ms step_avg:95.49ms +step:1387/1750 train_time:132449ms step_avg:95.49ms +step:1388/1750 train_time:132549ms step_avg:95.50ms +step:1389/1750 train_time:132649ms step_avg:95.50ms +step:1390/1750 train_time:132748ms step_avg:95.50ms +step:1391/1750 train_time:132846ms step_avg:95.50ms +step:1392/1750 train_time:132945ms step_avg:95.51ms +step:1393/1750 train_time:133043ms step_avg:95.51ms +step:1394/1750 train_time:133142ms step_avg:95.51ms +step:1395/1750 train_time:133241ms step_avg:95.51ms +step:1396/1750 
train_time:133340ms step_avg:95.52ms +step:1397/1750 train_time:133440ms step_avg:95.52ms +step:1398/1750 train_time:133540ms step_avg:95.52ms +step:1399/1750 train_time:133640ms step_avg:95.53ms +step:1400/1750 train_time:133739ms step_avg:95.53ms +step:1401/1750 train_time:133838ms step_avg:95.53ms +step:1402/1750 train_time:133938ms step_avg:95.53ms +step:1403/1750 train_time:134038ms step_avg:95.54ms +step:1404/1750 train_time:134137ms step_avg:95.54ms +step:1405/1750 train_time:134235ms step_avg:95.54ms +step:1406/1750 train_time:134335ms step_avg:95.54ms +step:1407/1750 train_time:134435ms step_avg:95.55ms +step:1408/1750 train_time:134535ms step_avg:95.55ms +step:1409/1750 train_time:134635ms step_avg:95.55ms +step:1410/1750 train_time:134734ms step_avg:95.56ms +step:1411/1750 train_time:134834ms step_avg:95.56ms +step:1412/1750 train_time:134934ms step_avg:95.56ms +step:1413/1750 train_time:135032ms step_avg:95.56ms +step:1414/1750 train_time:135132ms step_avg:95.57ms +step:1415/1750 train_time:135232ms step_avg:95.57ms +step:1416/1750 train_time:135330ms step_avg:95.57ms +step:1417/1750 train_time:135429ms step_avg:95.57ms +step:1418/1750 train_time:135528ms step_avg:95.58ms +step:1419/1750 train_time:135627ms step_avg:95.58ms +step:1420/1750 train_time:135727ms step_avg:95.58ms +step:1421/1750 train_time:135826ms step_avg:95.58ms +step:1422/1750 train_time:135925ms step_avg:95.59ms +step:1423/1750 train_time:136025ms step_avg:95.59ms +step:1424/1750 train_time:136126ms step_avg:95.59ms +step:1425/1750 train_time:136225ms step_avg:95.60ms +step:1426/1750 train_time:136324ms step_avg:95.60ms +step:1427/1750 train_time:136423ms step_avg:95.60ms +step:1428/1750 train_time:136523ms step_avg:95.60ms +step:1429/1750 train_time:136622ms step_avg:95.61ms +step:1430/1750 train_time:136722ms step_avg:95.61ms +step:1431/1750 train_time:136822ms step_avg:95.61ms +step:1432/1750 train_time:136922ms step_avg:95.62ms +step:1433/1750 train_time:137022ms step_avg:95.62ms 
+step:1434/1750 train_time:137122ms step_avg:95.62ms +step:1435/1750 train_time:137223ms step_avg:95.63ms +step:1436/1750 train_time:137324ms step_avg:95.63ms +step:1437/1750 train_time:137425ms step_avg:95.63ms +step:1438/1750 train_time:137524ms step_avg:95.64ms +step:1439/1750 train_time:137624ms step_avg:95.64ms +step:1440/1750 train_time:137726ms step_avg:95.64ms +step:1441/1750 train_time:137827ms step_avg:95.65ms +step:1442/1750 train_time:137925ms step_avg:95.65ms +step:1443/1750 train_time:138024ms step_avg:95.65ms +step:1444/1750 train_time:138125ms step_avg:95.65ms +step:1445/1750 train_time:138224ms step_avg:95.66ms +step:1446/1750 train_time:138324ms step_avg:95.66ms +step:1447/1750 train_time:138424ms step_avg:95.66ms +step:1448/1750 train_time:138524ms step_avg:95.67ms +step:1449/1750 train_time:138624ms step_avg:95.67ms +step:1450/1750 train_time:138723ms step_avg:95.67ms +step:1451/1750 train_time:138823ms step_avg:95.67ms +step:1452/1750 train_time:138922ms step_avg:95.68ms +step:1453/1750 train_time:139023ms step_avg:95.68ms +step:1454/1750 train_time:139125ms step_avg:95.68ms +step:1455/1750 train_time:139225ms step_avg:95.69ms +step:1456/1750 train_time:139324ms step_avg:95.69ms +step:1457/1750 train_time:139425ms step_avg:95.69ms +step:1458/1750 train_time:139525ms step_avg:95.70ms +step:1459/1750 train_time:139624ms step_avg:95.70ms +step:1460/1750 train_time:139724ms step_avg:95.70ms +step:1461/1750 train_time:139823ms step_avg:95.70ms +step:1462/1750 train_time:139922ms step_avg:95.71ms +step:1463/1750 train_time:140022ms step_avg:95.71ms +step:1464/1750 train_time:140122ms step_avg:95.71ms +step:1465/1750 train_time:140222ms step_avg:95.71ms +step:1466/1750 train_time:140321ms step_avg:95.72ms +step:1467/1750 train_time:140422ms step_avg:95.72ms +step:1468/1750 train_time:140522ms step_avg:95.72ms +step:1469/1750 train_time:140624ms step_avg:95.73ms +step:1470/1750 train_time:140724ms step_avg:95.73ms +step:1471/1750 train_time:140823ms 
step_avg:95.73ms +step:1472/1750 train_time:140923ms step_avg:95.74ms +step:1473/1750 train_time:141023ms step_avg:95.74ms +step:1474/1750 train_time:141123ms step_avg:95.74ms +step:1475/1750 train_time:141222ms step_avg:95.74ms +step:1476/1750 train_time:141323ms step_avg:95.75ms +step:1477/1750 train_time:141423ms step_avg:95.75ms +step:1478/1750 train_time:141524ms step_avg:95.75ms +step:1479/1750 train_time:141625ms step_avg:95.76ms +step:1480/1750 train_time:141725ms step_avg:95.76ms +step:1481/1750 train_time:141824ms step_avg:95.76ms +step:1482/1750 train_time:141925ms step_avg:95.77ms +step:1483/1750 train_time:142024ms step_avg:95.77ms +step:1484/1750 train_time:142125ms step_avg:95.77ms +step:1485/1750 train_time:142227ms step_avg:95.78ms +step:1486/1750 train_time:142327ms step_avg:95.78ms +step:1487/1750 train_time:142428ms step_avg:95.78ms +step:1488/1750 train_time:142530ms step_avg:95.79ms +step:1489/1750 train_time:142629ms step_avg:95.79ms +step:1490/1750 train_time:142729ms step_avg:95.79ms +step:1491/1750 train_time:142829ms step_avg:95.79ms +step:1492/1750 train_time:142929ms step_avg:95.80ms +step:1493/1750 train_time:143029ms step_avg:95.80ms +step:1494/1750 train_time:143129ms step_avg:95.80ms +step:1495/1750 train_time:143229ms step_avg:95.81ms +step:1496/1750 train_time:143329ms step_avg:95.81ms +step:1497/1750 train_time:143429ms step_avg:95.81ms +step:1498/1750 train_time:143529ms step_avg:95.81ms +step:1499/1750 train_time:143629ms step_avg:95.82ms +step:1500/1750 train_time:143730ms step_avg:95.82ms +step:1500/1750 val_loss:3.3314 train_time:143832ms step_avg:95.89ms +step:1501/1750 train_time:143856ms step_avg:95.84ms +step:1502/1750 train_time:143941ms step_avg:95.83ms +step:1503/1750 train_time:144043ms step_avg:95.84ms +step:1504/1750 train_time:144142ms step_avg:95.84ms +step:1505/1750 train_time:144241ms step_avg:95.84ms +step:1506/1750 train_time:144340ms step_avg:95.84ms +step:1507/1750 train_time:144438ms step_avg:95.84ms 
+step:1508/1750 train_time:144537ms step_avg:95.85ms +step:1509/1750 train_time:144635ms step_avg:95.85ms +step:1510/1750 train_time:144734ms step_avg:95.85ms +step:1511/1750 train_time:144835ms step_avg:95.85ms +step:1512/1750 train_time:144937ms step_avg:95.86ms +step:1513/1750 train_time:145038ms step_avg:95.86ms +step:1514/1750 train_time:145139ms step_avg:95.86ms +step:1515/1750 train_time:145243ms step_avg:95.87ms +step:1516/1750 train_time:145342ms step_avg:95.87ms +step:1517/1750 train_time:145441ms step_avg:95.87ms +step:1518/1750 train_time:145540ms step_avg:95.88ms +step:1519/1750 train_time:145640ms step_avg:95.88ms +step:1520/1750 train_time:145741ms step_avg:95.88ms +step:1521/1750 train_time:145842ms step_avg:95.89ms +step:1522/1750 train_time:145943ms step_avg:95.89ms +step:1523/1750 train_time:146042ms step_avg:95.89ms +step:1524/1750 train_time:146143ms step_avg:95.89ms +step:1525/1750 train_time:146243ms step_avg:95.90ms +step:1526/1750 train_time:146342ms step_avg:95.90ms +step:1527/1750 train_time:146442ms step_avg:95.90ms +step:1528/1750 train_time:146544ms step_avg:95.91ms +step:1529/1750 train_time:146644ms step_avg:95.91ms +step:1530/1750 train_time:146744ms step_avg:95.91ms +step:1531/1750 train_time:146844ms step_avg:95.91ms +step:1532/1750 train_time:146945ms step_avg:95.92ms +step:1533/1750 train_time:147044ms step_avg:95.92ms +step:1534/1750 train_time:147144ms step_avg:95.92ms +step:1535/1750 train_time:147244ms step_avg:95.92ms +step:1536/1750 train_time:147343ms step_avg:95.93ms +step:1537/1750 train_time:147442ms step_avg:95.93ms +step:1538/1750 train_time:147543ms step_avg:95.93ms +step:1539/1750 train_time:147642ms step_avg:95.93ms +step:1540/1750 train_time:147742ms step_avg:95.94ms +step:1541/1750 train_time:147844ms step_avg:95.94ms +step:1542/1750 train_time:147947ms step_avg:95.95ms +step:1543/1750 train_time:148049ms step_avg:95.95ms +step:1544/1750 train_time:148148ms step_avg:95.95ms +step:1545/1750 train_time:148249ms 
step_avg:95.95ms +step:1546/1750 train_time:148348ms step_avg:95.96ms +step:1547/1750 train_time:148452ms step_avg:95.96ms +step:1548/1750 train_time:148553ms step_avg:95.96ms +step:1549/1750 train_time:148654ms step_avg:95.97ms +step:1550/1750 train_time:148755ms step_avg:95.97ms +step:1551/1750 train_time:148855ms step_avg:95.97ms +step:1552/1750 train_time:148957ms step_avg:95.98ms +step:1553/1750 train_time:149056ms step_avg:95.98ms +step:1554/1750 train_time:149155ms step_avg:95.98ms +step:1555/1750 train_time:149256ms step_avg:95.98ms +step:1556/1750 train_time:149357ms step_avg:95.99ms +step:1557/1750 train_time:149457ms step_avg:95.99ms +step:1558/1750 train_time:149557ms step_avg:95.99ms +step:1559/1750 train_time:149657ms step_avg:96.00ms +step:1560/1750 train_time:149757ms step_avg:96.00ms +step:1561/1750 train_time:149858ms step_avg:96.00ms +step:1562/1750 train_time:149959ms step_avg:96.00ms +step:1563/1750 train_time:150062ms step_avg:96.01ms +step:1564/1750 train_time:150162ms step_avg:96.01ms +step:1565/1750 train_time:150262ms step_avg:96.01ms +step:1566/1750 train_time:150362ms step_avg:96.02ms +step:1567/1750 train_time:150461ms step_avg:96.02ms +step:1568/1750 train_time:150561ms step_avg:96.02ms +step:1569/1750 train_time:150662ms step_avg:96.02ms +step:1570/1750 train_time:150764ms step_avg:96.03ms +step:1571/1750 train_time:150864ms step_avg:96.03ms +step:1572/1750 train_time:150964ms step_avg:96.03ms +step:1573/1750 train_time:151065ms step_avg:96.04ms +step:1574/1750 train_time:151165ms step_avg:96.04ms +step:1575/1750 train_time:151266ms step_avg:96.04ms +step:1576/1750 train_time:151366ms step_avg:96.04ms +step:1577/1750 train_time:151468ms step_avg:96.05ms +step:1578/1750 train_time:151568ms step_avg:96.05ms +step:1579/1750 train_time:151668ms step_avg:96.05ms +step:1580/1750 train_time:151771ms step_avg:96.06ms +step:1581/1750 train_time:151871ms step_avg:96.06ms +step:1582/1750 train_time:151970ms step_avg:96.06ms +step:1583/1750 
train_time:152072ms step_avg:96.07ms +step:1584/1750 train_time:152173ms step_avg:96.07ms +step:1585/1750 train_time:152272ms step_avg:96.07ms +step:1586/1750 train_time:152372ms step_avg:96.07ms +step:1587/1750 train_time:152472ms step_avg:96.08ms +step:1588/1750 train_time:152573ms step_avg:96.08ms +step:1589/1750 train_time:152672ms step_avg:96.08ms +step:1590/1750 train_time:152773ms step_avg:96.08ms +step:1591/1750 train_time:152873ms step_avg:96.09ms +step:1592/1750 train_time:152975ms step_avg:96.09ms +step:1593/1750 train_time:153074ms step_avg:96.09ms +step:1594/1750 train_time:153178ms step_avg:96.10ms +step:1595/1750 train_time:153277ms step_avg:96.10ms +step:1596/1750 train_time:153377ms step_avg:96.10ms +step:1597/1750 train_time:153476ms step_avg:96.10ms +step:1598/1750 train_time:153579ms step_avg:96.11ms +step:1599/1750 train_time:153679ms step_avg:96.11ms +step:1600/1750 train_time:153781ms step_avg:96.11ms +step:1601/1750 train_time:153882ms step_avg:96.12ms +step:1602/1750 train_time:153982ms step_avg:96.12ms +step:1603/1750 train_time:154082ms step_avg:96.12ms +step:1604/1750 train_time:154182ms step_avg:96.12ms +step:1605/1750 train_time:154282ms step_avg:96.13ms +step:1606/1750 train_time:154381ms step_avg:96.13ms +step:1607/1750 train_time:154482ms step_avg:96.13ms +step:1608/1750 train_time:154581ms step_avg:96.13ms +step:1609/1750 train_time:154681ms step_avg:96.13ms +step:1610/1750 train_time:154782ms step_avg:96.14ms +step:1611/1750 train_time:154883ms step_avg:96.14ms +step:1612/1750 train_time:154984ms step_avg:96.14ms +step:1613/1750 train_time:155084ms step_avg:96.15ms +step:1614/1750 train_time:155184ms step_avg:96.15ms +step:1615/1750 train_time:155285ms step_avg:96.15ms +step:1616/1750 train_time:155385ms step_avg:96.15ms +step:1617/1750 train_time:155485ms step_avg:96.16ms +step:1618/1750 train_time:155586ms step_avg:96.16ms +step:1619/1750 train_time:155686ms step_avg:96.16ms +step:1620/1750 train_time:155786ms step_avg:96.16ms 
+step:1621/1750 train_time:155887ms step_avg:96.17ms +step:1622/1750 train_time:155986ms step_avg:96.17ms +step:1623/1750 train_time:156086ms step_avg:96.17ms +step:1624/1750 train_time:156187ms step_avg:96.17ms +step:1625/1750 train_time:156290ms step_avg:96.18ms +step:1625/1750 val_loss:3.3014 train_time:156394ms step_avg:96.24ms +step:1626/1750 train_time:156420ms step_avg:96.20ms +step:1627/1750 train_time:156503ms step_avg:96.19ms +step:1628/1750 train_time:156603ms step_avg:96.19ms +step:1629/1750 train_time:156703ms step_avg:96.20ms +step:1630/1750 train_time:156802ms step_avg:96.20ms +step:1631/1750 train_time:156902ms step_avg:96.20ms +step:1632/1750 train_time:157002ms step_avg:96.20ms +step:1633/1750 train_time:157101ms step_avg:96.20ms +step:1634/1750 train_time:157202ms step_avg:96.21ms +step:1635/1750 train_time:157302ms step_avg:96.21ms +step:1636/1750 train_time:157404ms step_avg:96.21ms +step:1637/1750 train_time:157507ms step_avg:96.22ms +step:1638/1750 train_time:157608ms step_avg:96.22ms +step:1639/1750 train_time:157709ms step_avg:96.22ms +step:1640/1750 train_time:157808ms step_avg:96.22ms +step:1641/1750 train_time:157907ms step_avg:96.23ms +step:1642/1750 train_time:158006ms step_avg:96.23ms +step:1643/1750 train_time:158105ms step_avg:96.23ms +step:1644/1750 train_time:158205ms step_avg:96.23ms +step:1645/1750 train_time:158305ms step_avg:96.23ms +step:1646/1750 train_time:158406ms step_avg:96.24ms +step:1647/1750 train_time:158509ms step_avg:96.24ms +step:1648/1750 train_time:158611ms step_avg:96.24ms +step:1649/1750 train_time:158712ms step_avg:96.25ms +step:1650/1750 train_time:158811ms step_avg:96.25ms +step:1651/1750 train_time:158910ms step_avg:96.25ms +step:1652/1750 train_time:159010ms step_avg:96.25ms +step:1653/1750 train_time:159111ms step_avg:96.26ms +step:1654/1750 train_time:159211ms step_avg:96.26ms +step:1655/1750 train_time:159311ms step_avg:96.26ms +step:1656/1750 train_time:159412ms step_avg:96.26ms +step:1657/1750 
train_time:159512ms step_avg:96.27ms +step:1658/1750 train_time:159614ms step_avg:96.27ms +step:1659/1750 train_time:159719ms step_avg:96.27ms +step:1660/1750 train_time:159820ms step_avg:96.28ms +step:1661/1750 train_time:159921ms step_avg:96.28ms +step:1662/1750 train_time:160021ms step_avg:96.28ms +step:1663/1750 train_time:160122ms step_avg:96.29ms +step:1664/1750 train_time:160222ms step_avg:96.29ms +step:1665/1750 train_time:160323ms step_avg:96.29ms +step:1666/1750 train_time:160425ms step_avg:96.29ms +step:1667/1750 train_time:160525ms step_avg:96.30ms +step:1668/1750 train_time:160628ms step_avg:96.30ms +step:1669/1750 train_time:160729ms step_avg:96.30ms +step:1670/1750 train_time:160830ms step_avg:96.31ms +step:1671/1750 train_time:160930ms step_avg:96.31ms +step:1672/1750 train_time:161031ms step_avg:96.31ms +step:1673/1750 train_time:161130ms step_avg:96.31ms +step:1674/1750 train_time:161232ms step_avg:96.32ms +step:1675/1750 train_time:161334ms step_avg:96.32ms +step:1676/1750 train_time:161434ms step_avg:96.32ms +step:1677/1750 train_time:161535ms step_avg:96.32ms +step:1678/1750 train_time:161635ms step_avg:96.33ms +step:1679/1750 train_time:161736ms step_avg:96.33ms +step:1680/1750 train_time:161836ms step_avg:96.33ms +step:1681/1750 train_time:161938ms step_avg:96.33ms +step:1682/1750 train_time:162042ms step_avg:96.34ms +step:1683/1750 train_time:162142ms step_avg:96.34ms +step:1684/1750 train_time:162244ms step_avg:96.34ms +step:1685/1750 train_time:162343ms step_avg:96.35ms +step:1686/1750 train_time:162443ms step_avg:96.35ms +step:1687/1750 train_time:162543ms step_avg:96.35ms +step:1688/1750 train_time:162644ms step_avg:96.35ms +step:1689/1750 train_time:162744ms step_avg:96.36ms +step:1690/1750 train_time:162846ms step_avg:96.36ms +step:1691/1750 train_time:162948ms step_avg:96.36ms +step:1692/1750 train_time:163048ms step_avg:96.36ms +step:1693/1750 train_time:163149ms step_avg:96.37ms +step:1694/1750 train_time:163250ms step_avg:96.37ms 
+step:1695/1750 train_time:163353ms step_avg:96.37ms +step:1696/1750 train_time:163453ms step_avg:96.38ms +step:1697/1750 train_time:163556ms step_avg:96.38ms +step:1698/1750 train_time:163656ms step_avg:96.38ms +step:1699/1750 train_time:163757ms step_avg:96.38ms +step:1700/1750 train_time:163858ms step_avg:96.39ms +step:1701/1750 train_time:163958ms step_avg:96.39ms +step:1702/1750 train_time:164063ms step_avg:96.39ms +step:1703/1750 train_time:164164ms step_avg:96.40ms +step:1704/1750 train_time:164266ms step_avg:96.40ms +step:1705/1750 train_time:164366ms step_avg:96.40ms +step:1706/1750 train_time:164466ms step_avg:96.40ms +step:1707/1750 train_time:164569ms step_avg:96.41ms +step:1708/1750 train_time:164672ms step_avg:96.41ms +step:1709/1750 train_time:164774ms step_avg:96.42ms +step:1710/1750 train_time:164875ms step_avg:96.42ms +step:1711/1750 train_time:164976ms step_avg:96.42ms +step:1712/1750 train_time:165077ms step_avg:96.42ms +step:1713/1750 train_time:165178ms step_avg:96.43ms +step:1714/1750 train_time:165277ms step_avg:96.43ms +step:1715/1750 train_time:165381ms step_avg:96.43ms +step:1716/1750 train_time:165484ms step_avg:96.44ms +step:1717/1750 train_time:165585ms step_avg:96.44ms +step:1718/1750 train_time:165686ms step_avg:96.44ms +step:1719/1750 train_time:165790ms step_avg:96.45ms +step:1720/1750 train_time:165891ms step_avg:96.45ms +step:1721/1750 train_time:165993ms step_avg:96.45ms +step:1722/1750 train_time:166094ms step_avg:96.45ms +step:1723/1750 train_time:166195ms step_avg:96.46ms +step:1724/1750 train_time:166297ms step_avg:96.46ms +step:1725/1750 train_time:166398ms step_avg:96.46ms +step:1726/1750 train_time:166500ms step_avg:96.47ms +step:1727/1750 train_time:166601ms step_avg:96.47ms +step:1728/1750 train_time:166705ms step_avg:96.47ms +step:1729/1750 train_time:166809ms step_avg:96.48ms +step:1730/1750 train_time:166908ms step_avg:96.48ms +step:1731/1750 train_time:167009ms step_avg:96.48ms +step:1732/1750 train_time:167110ms 
step_avg:96.48ms +step:1733/1750 train_time:167212ms step_avg:96.49ms +step:1734/1750 train_time:167314ms step_avg:96.49ms +step:1735/1750 train_time:167415ms step_avg:96.49ms +step:1736/1750 train_time:167516ms step_avg:96.50ms +step:1737/1750 train_time:167617ms step_avg:96.50ms +step:1738/1750 train_time:167718ms step_avg:96.50ms +step:1739/1750 train_time:167820ms step_avg:96.50ms +step:1740/1750 train_time:167921ms step_avg:96.51ms +step:1741/1750 train_time:168026ms step_avg:96.51ms +step:1742/1750 train_time:168126ms step_avg:96.51ms +step:1743/1750 train_time:168226ms step_avg:96.52ms +step:1744/1750 train_time:168327ms step_avg:96.52ms +step:1745/1750 train_time:168427ms step_avg:96.52ms +step:1746/1750 train_time:168528ms step_avg:96.52ms +step:1747/1750 train_time:168630ms step_avg:96.53ms +step:1748/1750 train_time:168733ms step_avg:96.53ms +step:1749/1750 train_time:168834ms step_avg:96.53ms +step:1750/1750 train_time:168935ms step_avg:96.53ms +step:1750/1750 val_loss:3.2780 train_time:169042ms step_avg:96.60ms +peak memory allocated: 34460 MiB reserved: 48914 MiB diff --git a/requirements.txt b/requirements.txt index fe83bb138..d333f1e8f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ numpy tqdm torch huggingface-hub +triton diff --git a/train_gpt.py b/train_gpt.py index 57ccce211..7a429c0de 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -7,7 +7,7 @@ import copy import glob from dataclasses import dataclass -from functools import lru_cache, partial # Added partial for hook registration +from functools import lru_cache from pathlib import Path os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -19,6 +19,8 @@ # use of FlexAttention contributed by @KoszarskyB from torch.nn.attention.flex_attention import BlockMask, flex_attention #torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min +import triton +import triton.language as 
tl # ----------------------------------------------------------------------------- # Custom operators: FP8 matmul by @YouJiacheng @@ -102,37 +104,287 @@ def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output): mm_op.register_autograd(backward, setup_context=setup_context) # ----------------------------------------------------------------------------- -# Muon optimizer +# Triton kernel for symmetric matrix multiplication by @byronxu99 + +def _get_autotune_configs(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": bm, + "BLOCK_SIZE_N": bn, + "BLOCK_SIZE_K": bk, + "GROUP_SIZE_M": 8, + "LOWER_UPPER": 1, + }, + num_stages=stages, + num_warps=warps, + ) + for bm in [64, 128] + for bn in [64, 128, 256] + for bk in [64, 128] + for stages, warps in [(3, 4), (3, 8), (4, 4)] + if bm // bn <= 2 and bn // bm <= 2 + ] + +@triton.jit +def _pid_to_block( + pid, + M, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(M, BLOCK_SIZE_N) + + # Map PID to a single matrix in batch + batch_idx = pid // (num_pid_m * num_pid_n) + pid = pid % (num_pid_m * num_pid_n) + + # Map PID to 2D grid of blocks + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M) + + m_idx = pid_m * BLOCK_SIZE_M + n_idx = pid_n * BLOCK_SIZE_N + return batch_idx, m_idx, n_idx + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_1_kernel( + A_ptr, C_ptr, + M, K, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = 
_pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_1(A: torch.Tensor, out: torch.Tensor): + """ + Launch Triton kernel to compute C = A @ A.T + """ + assert A.ndim == 2 or A.ndim == 3 
+ M, K = A.shape[-2:] + assert out.size(-2) == M, "Output matrix has incorrect shape" + assert out.size(-1) == M, "Output matrix has incorrect shape" + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 -@torch.compile -def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor: + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_1_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + K=K, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + ) + return out + +@triton.autotune( + configs=_get_autotune_configs(), + key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"], +) +@triton.jit +def ns_line_2_kernel( + A_ptr, C_ptr, + M, + a_stride_b, a_stride_r, a_stride_c, + c_stride_b, c_stride_r, c_stride_c, + alpha, beta, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + LOWER_UPPER: tl.constexpr, +): + # This is mostly duplicated from ns_line_1_kernel, but also loads and adds a block of A + # Performance is slightly slower than ns_line_1_kernel, so we use two separate kernels + pid = tl.program_id(axis=0) + batch_idx, m_idx, n_idx = _pid_to_block( + pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M + ) + + # Skip blocks that don't need to be computed + skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx) + skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx) + if skip_block_below_diag or skip_block_above_diag: + return + + # Index into one matrix of batch + A_ptr += batch_idx * a_stride_b + C_ptr += batch_idx * c_stride_b + + # Create pointer arrays for A and A.T + offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = 
(n_idx + tl.arange(0, BLOCK_SIZE_N)) % M + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c) + at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Accumulate over blocks of K + for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0) + at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, at, accumulator) + a_ptrs += BLOCK_SIZE_K * a_stride_c + at_ptrs += BLOCK_SIZE_K * a_stride_c + + # Load block of A to add (corresponds to the current block of C) + offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N) + a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c) + a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M) + a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32) + + # Apply alpha and beta + accumulator *= alpha + accumulator += a_add * beta + + out_dtype = C_ptr.dtype.element_ty + output = accumulator.to(out_dtype) + + # Store block of C + offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M) + offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M) + tl.store(c_ptrs, output, mask=c_mask) + + # Store block of C mirrored across the diagonal + c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c) + c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M) + tl.store(c_ptrs_t, output.T, mask=c_mask_t) + +def ns_line_2(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor): """ - Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. 
We opt to use a - quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose - of minimizing steps, it turns out to be empirically effective to keep increasing the slope at - zero even beyond the point where the iteration no longer converges all the way to one everywhere - on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T - where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model - performance at all relative to UV^T, where USV^T = G is the SVD. + Launch Triton kernel to compute C = alpha * A @ A.T + beta * A """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) - X = G + assert A.ndim == 2 or A.ndim == 3 + M, K = A.shape[-2:] + assert M == K, "Input matrix must be square" + assert out.size(-2) == M + assert out.size(-1) == M + + batch_size = A.size(0) if A.ndim == 3 else 1 + input_batch_stride = A.stride(0) if A.ndim == 3 else 0 + output_batch_stride = out.stride(0) if out.ndim == 3 else 0 + + grid = lambda meta: ( + batch_size * triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(M, meta["BLOCK_SIZE_N"]), + ) + ns_line_2_kernel[grid]( + A_ptr=A, + C_ptr=out, + M=M, + a_stride_b=input_batch_stride, + a_stride_r=A.stride(-2), + a_stride_c=A.stride(-1), + c_stride_b=output_batch_stride, + c_stride_r=out.stride(-2), + c_stride_c=out.stride(-1), + alpha=alpha, + beta=beta, + ) + return out + +@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower +def newton_schulz_triton(G: torch.Tensor): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT # Ensure spectral norm is at most 1 X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7) + + # Allocate buffers + X = X.contiguous() + A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype) + B = 
torch.empty_like(A) + C = torch.empty_like(X) + + ns_line_3 = torch.baddbmm if X.ndim > 2 else torch.addmm + # Perform the NS iterations - for _ in range(steps): - A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng - X = a * X + B @ X + for _ in range(5): + ns_line_1(X, out=A) # A = X @ X.mT + ns_line_2(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A + ns_line_3(X, B, X, beta=a, out=C) # C = a * X + B @ X + X, C = C, X # Swap references to avoid unnecessary copies if G.size(-2) > G.size(-1): X = X.mT return X +# ----------------------------------------------------------------------------- +# Muon optimizer + class Muon(torch.optim.Optimizer): """ Muon - MomentUm Orthogonalized by Newton-schulz @@ -166,7 +418,7 @@ def step(self): rank = dist.get_rank() world_size = dist.get_world_size() reduce_scatter_futures: list[torch.Future] = [] - all_reduce_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] for group in self.param_groups: params: list[Tensor] = group["params"] grad = torch.empty_like(params[-1]) @@ -196,11 +448,11 @@ def step(self): p.mul_(1 - eff_weight_decay) momentum_buffer.lerp_(grad, 1 - momentum) grad = grad.lerp_(momentum_buffer, momentum) - v = zeropower_via_newtonschulz5(grad.bfloat16(), 5) + v = newton_schulz_triton(grad) p.add_(other=v, alpha=-eff_lr) idx += 1 - all_reduce_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) - torch.futures.collect_all(all_reduce_futures).wait() + all_gather_futures.append(dist.all_gather(params_pad[base_i:base_i + world_size], params_pad[base_i + rank], async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() class DistAdam(torch.optim.Optimizer): def __init__(self, params, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01): @@ -221,7 +473,7 @@ def 
step(self): rank = dist.get_rank() world_size = dist.get_world_size() reduce_scatter_futures: list[torch.Future] = [] - all_reduce_futures: list[torch.Future] = [] + all_gather_futures: list[torch.Future] = [] grad_slices = [] for group in self.param_groups: params: list[Tensor] = group["params"] @@ -272,8 +524,8 @@ def step(self): update = exp_avg.div(denom).mul_(step_size) p_slice.add_(other=update, alpha=-1.0) idx += 1 - all_reduce_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) - torch.futures.collect_all(all_reduce_futures).wait() + all_gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()) + torch.futures.collect_all(all_gather_futures).wait() # ----------------------------------------------------------------------------- # PyTorch nn.Module definitions for the model @@ -328,14 +580,16 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): self.num_heads = num_heads self.head_dim = head_dim hdim = num_heads * head_dim + assert hdim == dim, "num_heads * head_dim must equal model_dim" std = 0.5 * (dim ** -0.5) bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng # https://x.com/hi_tysam/status/1879699187107033311 - self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound)) + self.qkvo_w = nn.Parameter(torch.empty(4, hdim, dim)) + with torch.no_grad(): + self.qkvo_w[:3].uniform_(-bound, bound) # init QKV weights + self.qkvo_w[3].zero_() # init output weights to zero self.rotary = Rotary(head_dim, max_seq_len) - self.c_proj = CastedLinear(hdim, dim) - self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977 # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283 
self.attn_scale = 0.12 @@ -343,7 +597,7 @@ def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128): def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: BlockMask): B, T = x.size(0), x.size(1) # batch size, sequence length assert B == 1, "Must use batch size = 1 for FlexAttention" - q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) + q, k, v = F.linear(x, self.qkvo_w[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2) q, k = norm(q), norm(k) # QK norm @Grad62304977 q, k = self.rotary(q), self.rotary(k) if ve is not None: @@ -352,21 +606,27 @@ def forward(self, x: Tensor, ve: Tensor | None, lambdas: Tensor, block_mask: Blo v = lambdas[0] * v y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=self.attn_scale).transpose(1, 2) y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side - y = self.c_proj(y) + y = F.linear(y, self.qkvo_w[3].type_as(y)) return y class MLP(nn.Module): def __init__(self, dim: int): super().__init__() hdim = 4 * dim - self.c_fc = CastedLinear(dim, hdim) - self.c_proj = CastedLinear(hdim, dim) - self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977 + # make both matrices have the same shape because optimizer sorts params by shape + # 2 matrices x 12 layers = 24 total, which is divisible by 8 GPU world size + self.c_fc = nn.Parameter(torch.empty(dim, hdim)) + self.c_proj = nn.Parameter(torch.empty(dim, hdim)) + std = 0.5 * (dim ** -0.5) + bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng + with torch.no_grad(): + self.c_fc.uniform_(-bound, bound) + self.c_proj.zero_() # zero init suggested by @Grad62304977 def forward(self, x: Tensor): - x = self.c_fc(x) + x = F.linear(x, self.c_fc.T.type_as(x)) x = F.relu(x).square() # 
https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977 - x = self.c_proj(x) + x = F.linear(x, self.c_proj.type_as(x)) return x class Block(nn.Module): @@ -400,7 +660,8 @@ def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)]) # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. # suggested to me by @Grad62304977. this originates from Karpathy's experiments. - self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448) + use_fp8 = not os.environ.get("DISABLE_FP8", False) + self.lm_head = CastedLinear(model_dim, vocab_size, use_fp8=use_fp8, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448) self.lm_head.weight.detach().zero_() # @Grad62304977 # Add learnable skip connection weights for decoder layers assert num_layers % 2 == 0 @@ -511,43 +772,45 @@ def _load_data_shard(file: Path): return tokens # find world_size starting indicies, such that each begins with token 50256 and local_batches don't overlap -def find_batch_starts(tokens: Tensor, pos: int, local_batch_size: int, max_batch_span: int): - boundary_mask = tokens[pos : pos + max_batch_span] == 50256 +def find_batch_starts(tokens: Tensor, pos: int, seq_len: int, token_window: int): + boundary_mask = tokens[pos : pos + token_window] == 50256 boundary_positions = torch.nonzero(boundary_mask, as_tuple=False).squeeze(-1) + pos start = boundary_positions[0].item() starts = [] for i in range(1, len(boundary_positions)): end = boundary_positions[i].item() - if end - start >= local_batch_size: + if end - start >= seq_len: starts.append(start) # append start once end pos is confirmed if len(starts) == dist.get_world_size(): return starts, end - pos start = end - assert False # increase max_batch_span if necessary + assert False # increase 
token_window if necessary -def distributed_data_generator(filename_pattern: str, batch_size: int, align_to_bos: bool): +def distributed_data_generator(filename_pattern: str, seq_len: int, grad_accum_steps: int, align_to_bos: bool): rank = dist.get_rank() world_size = dist.get_world_size() + batch_size = seq_len * world_size files = [Path(file) for file in sorted(glob.glob(filename_pattern))] - assert batch_size % world_size == 0 - local_batch_size = batch_size // world_size file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training tokens, pos = _load_data_shard(next(file_iter)), 0 - max_batch_span = 2 * batch_size if align_to_bos else batch_size # provide buffer to handle samples up to length local_batch_size while True: - if pos + max_batch_span + 1 >= len(tokens): - tokens, pos = _load_data_shard(next(file_iter)), 0 - if align_to_bos: - batch_starts, batch_span = find_batch_starts(tokens, pos, local_batch_size, max_batch_span) - start_idx = batch_starts[rank] - else: - batch_span = batch_size - start_idx = pos + rank * local_batch_size - buf = tokens[start_idx:][:local_batch_size + 1] - inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; - targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. 
- pos += batch_span - yield inputs, targets + token_window = grad_accum_steps * (2 * batch_size if align_to_bos else batch_size) # provide buffer to handle samples up to length seq_len + if pos + token_window + 1 >= len(tokens): + tokens = _load_data_shard(next(file_iter)) + pos = 0 + for _ in range(grad_accum_steps): + if align_to_bos: + batch_starts, tokens_consumed = find_batch_starts(tokens, pos, seq_len, token_window) + start_idx = batch_starts[rank] + else: + tokens_consumed = batch_size + start_idx = pos + rank * seq_len + buf = tokens[start_idx:][:seq_len + 1] + inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side; + targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful. + pos += tokens_consumed + token_window -= tokens_consumed + yield inputs, targets # ----------------------------------------------------------------------------- # int main @@ -568,10 +831,15 @@ class Hyperparameters: save_checkpoint = False args = Hyperparameters() +data_path = os.environ.get("DATA_PATH", ".") +args.train_files = os.path.join(data_path, args.train_files) +args.val_files = os.path.join(data_path, args.val_files) + # torchrun sets these env variables rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) -assert world_size == 8 # this code is designed for 8xH100 +assert 8 % world_size == 0, "world_size must be a divisor of 8" +grad_accum_steps = 8 // world_size assert torch.cuda.is_available() device = torch.device("cuda", int(os.environ["LOCAL_RANK"])) torch.cuda.set_device(device) @@ -599,6 +867,7 @@ def print0(s, console=False): # log information about the hardware/software environment this is running on print0(f"Running Python {sys.version}") print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}") +print0(f"Running Triton version {triton.__version__}") def nvidia_smi(): import subprocess # avoid top level import 
return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout @@ -650,7 +919,7 @@ def get_window_size_blocks(step: int): window_size = next_multiple_of_n(1728 * x, n=128) return get_window_size_blocks_helper(window_size) -model: nn.Module = torch.compile(model, dynamic=False) +model: nn.Module = torch.compile(model, dynamic=False, fullgraph=True) ######################################## # Warmup kernels # @@ -660,7 +929,7 @@ def get_window_size_blocks(step: int): warmup_steps = 10 initial_state = dict(model=copy.deepcopy(model.state_dict()), optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state -train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, align_to_bos=True) +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) for _ in range(warmup_steps): inputs, targets = next(train_loader) model(inputs, targets, get_window_size_blocks(1)).backward() @@ -676,7 +945,7 @@ def get_window_size_blocks(step: int): # Training and validation # ######################################## -train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, align_to_bos=True) +train_loader = distributed_data_generator(args.train_files, args.train_seq_len, grad_accum_steps, align_to_bos=True) training_time_ms = 0 # start the clock torch.cuda.synchronize() @@ -695,7 +964,7 @@ def get_window_size_blocks(step: int): val_batch_size = world_size * args.val_seq_len assert args.val_tokens % val_batch_size == 0 val_steps = args.val_tokens // val_batch_size - val_loader = distributed_data_generator(args.val_files, val_batch_size, align_to_bos=False) + val_loader = distributed_data_generator(args.val_files, args.val_seq_len, grad_accum_steps, align_to_bos=False) val_loss = 0 with torch.no_grad(): for _ in range(val_steps): @@ -719,8 +988,9 @@ def get_window_size_blocks(step: int): break # 
--------------- TRAINING SECTION ----------------- - inputs, targets = next(train_loader) - model(inputs, targets, get_window_size_blocks(step)).backward() + for _ in range(grad_accum_steps): + inputs, targets = next(train_loader) + model(inputs, targets, get_window_size_blocks(step)).backward() # set optimization hyperparameters for opt in optimizers: for group in opt.param_groups: @@ -739,4 +1009,4 @@ def get_window_size_blocks(step: int): print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True) -dist.destroy_process_group() \ No newline at end of file +dist.destroy_process_group()