From 00c269c51b3b4299ab60d72e6aa6deacb5eb5479 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Fri, 30 Dec 2022 10:14:34 +0100 Subject: [PATCH 01/24] test(optimizer): no truegrad --- src/model/main.py | 3 +-- src/optimizer.py | 5 +---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/model/main.py b/src/model/main.py index 409ea587..57f89284 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -65,8 +65,7 @@ def stem(ctx: Context, src: FourArrays) -> FourArrays: return src -def body_ctx(ctx: Context, src: jax.Array) -> Union[ - Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: +def body_ctx(ctx: Context, src: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: src = input_embed(ctx, src) zero = jnp.zeros_like(src) src = stem(ctx, (src, zero, src, zero)) diff --git a/src/optimizer.py b/src/optimizer.py index 52686fbd..2b0c2dfc 100644 --- a/src/optimizer.py +++ b/src/optimizer.py @@ -66,14 +66,11 @@ def graft(param_name: str, magnitude: jax.Array, direction: jax.Array) -> jax.Ar def tg_adam(ctx: Context, param_name: str, grad: jax.Array, tg_grad: jax.Array, step: jax.Array) -> jax.Array: ema_g = ema(ctx, grad, step, 1 - ctx.optimizer.adam_beta1) ema_gsq = ema(ctx, grad ** 2, step, 1 - ctx.optimizer.adam_beta2) - ema_tgsq = ema(ctx, tg_grad, step, 1 - ctx.optimizer.adam_beta3) if ctx.is_initializing: return grad - adam_update = ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon) - tg_update = ema_g * stable_rsqrt(ema_tgsq, ctx.optimizer.epsilon) - return graft(param_name, adam_update, tg_update) + return ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon) def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array: From cdc8d0db18aa0fa238bd637e55f8254fa19689ca Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 22 Jan 2023 13:19:31 +0100 Subject: [PATCH 02/24] feat(model): use mish + merge mixer/conv --- README.md | 5 ++- config.yaml | 3 -- src/context.py | 3 -- src/model/activate.py | 10 ++++-- src/model/conv.py | 33 +++++++++++++++++--- src/model/main.py | 2 -- src/model/mixer.py | 59 ----------------------------------- src/model/norm.py | 43 ++++++++++--------------- unittests/consistency/step.py | 2 -- unittests/grad/leak.py | 2 -- unittests/grad/norm.py | 25 +++++++-------- 11 files changed, 67 insertions(+), 120 deletions(-) delete mode 100644 src/model/mixer.py diff --git a/README.md b/README.md index b5861dcc..04445056 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,10 @@ running quickly and with a small memory footprint of O(N^1.5) instead of O(N^2). Therefore, an axial mlp-mixer needs less memory and compute than a standard transformer while providing better performance at scale. +**UPDATE** +This is now merged with the dense convolution by running `cat([x.T, x]) @ W_big` rather than `x @ W` and in a separate +block `x.T @ W_2`. + #### Reversible Most commonly, large transformers use [activation checkpointing](https://arxiv.org/abs/1604.06174v2), which saves the @@ -90,7 +94,6 @@ speedup for both training and inference. 
### Optimizer - #### Adaptive Gradient Clipping ## Getting Started diff --git a/config.yaml b/config.yaml index 81e49957..ada90f9e 100644 --- a/config.yaml +++ b/config.yaml @@ -18,8 +18,6 @@ dims: moe_intermediate: 4096 one: 1 outer_bottleneck_kernel: 25 - pointwise_features: 512 - pointwise_kernel: 5 sequence: 4096 vocab: 256 global_prefix: '' @@ -47,7 +45,6 @@ optimizer: momentum_beta: 0.1 norm_scale: 1 output_scale: 1 - pointwise_scale: 1 preconditioning_compute_steps: 128 qrnn_scale: 1 skip_preconditioning_dim_size_gt: 1024 diff --git a/src/context.py b/src/context.py index 71cc0665..3c81bd89 100644 --- a/src/context.py +++ b/src/context.py @@ -75,10 +75,7 @@ class Dims(DataClass): outer_bottleneck_kernel: int = 25 inner_bottleneck_kernel: int = 49 inner_bottleneck_features: int = 128 - pointwise_kernel: int = 5 features: int = 256 - spatial_mixing_kernel: int = 512 - pointwise_features: int = 512 sequence: int = 4096 depth: int = 8 vocab: int = 256 diff --git a/src/model/activate.py b/src/model/activate.py index 01d2f779..0a99633c 100644 --- a/src/model/activate.py +++ b/src/model/activate.py @@ -1,13 +1,17 @@ import jax -from jax import numpy as jnp +from jax import lax +# [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992) showed that +# "complicated" activation functions such as SiLU, GeLU and Mish help increase the rank of the embedding matrix. +# To ensure low-redundancy features, we use a function from that family. def activate_forward(inp: jax.Array) -> jax.Array: - return inp * activate_grad(inp) + return inp * lax.tanh(lax.max(inp, 0) + lax.log1p(lax.exp(-lax.abs(inp)))) # jax.nn.softplus without nan check def activate_grad(inp: jax.Array) -> jax.Array: - return jnp.where(inp < 0, 0.01, 1) + tanh_sp = lax.tanh(lax.max(inp, 0) + lax.log1p(lax.exp(-lax.abs(inp)))) + return tanh_sp + inp * jax.nn.sigmoid() * (1 - lax.square(tanh_sp)) def activate(inp: jax.Array) -> jax.Array: diff --git a/src/model/conv.py b/src/model/conv.py index 1ffc4eca..2d17b90f 100644 --- a/src/model/conv.py +++ b/src/model/conv.py @@ -1,7 +1,9 @@ +import math + import jax from jax import numpy as jnp -from src.backend import conv as lax_conv, get_param, square_grad, with_context +from src.backend import get_param, square_grad, with_context, conv as lax_conv, pattern_match from src.context import Context from src.model.norm import prenorm, scale_norm_act @@ -25,7 +27,28 @@ def _conv(x, y): @prenorm @with_context() -def dense_block(ctx: Context, inp: jax.Array) -> jax.Array: - inp = conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.features, ctx.dims.pointwise_features) - inp = scale_norm_act(ctx, inp, ctx.dims.pointwise_features) - return conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.pointwise_features, ctx.dims.features) +def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: + # Following [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992), we're + # not increasing the dimensionality in the middle, as the rank doesn't increase -> no useful features are added. 
+ + original_shape = inp.shape + original_batch, sequence, features = original_shape + max_dims = math.ceil(math.log(sequence, ctx.dims.features)) + + def _get_mix_fn(current_depth: int): + outer_sequence = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) + + def _fn(x: jax.Array): + out = x.reshape(original_batch, outer_sequence, features, -1) + out = jnp.transpose(out, (0, 1, 3, 2)) + return out.reshape(original_batch, sequence, features) + + return _fn + + pad = jnp.zeros((original_batch, ctx.dims.features - 1, features), inp.dtype) + inp_padded = jnp.concatenate([pad, inp[:, :1 - ctx.dims.features]], 1) + inp = jnp.concatenate([inp, pattern_match(_get_mix_fn, max_dims, depth, inp_padded)], -1) + + inp = conv(ctx, inp, 5, 2 * ctx.dims.features, 2 * ctx.dims.features) + inp = scale_norm_act(ctx, inp, 2 * ctx.dims.features, double=True) + return conv(ctx, inp, 5, 4 * ctx.dims.features, ctx.dims.features) diff --git a/src/model/main.py b/src/model/main.py index 57f89284..1497e7c9 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -7,7 +7,6 @@ from src.context import Context from src.model.conv import dense_block from src.model.loss import cross_entropy_loss -from src.model.mixer import mix from src.model.moe import dense_moe from src.model.norm import scale_norm_act from src.model.reversible import FourArrays, reversible, revnet_out @@ -40,7 +39,6 @@ def _fn(carry: FourArrays, inp: Tuple[Dict[str, jax.Array], jax.Array]): src = reversible(ctx, dense_block, src) src = reversible(ctx, dense_moe, src) src = reversible(ctx, dense_block, src) - src = reversible(ctx, mix, src, depth) name_cache.update(ctx.name_cache) if ctx.is_initializing: return src diff --git a/src/model/mixer.py b/src/model/mixer.py deleted file mode 100644 index 7d2b42a1..00000000 --- a/src/model/mixer.py +++ /dev/null @@ -1,59 +0,0 @@ -import math -from typing import Sequence - -import jax -from jax import numpy as jnp - -from src.backend import dot, get_param, pattern_match, square_grad, with_context -from src.context import Context -from src.model.norm import prenorm, scale_norm_act - - -def dot_sq(src: jax.Array, weight: jax.Array, weight_sq: jax.Array, - left_contract_dims: Sequence[int], right_contract_dims: Sequence[int]): - def _dot(x, y): - return dot(x, y, left_contract_dims=left_contract_dims, right_contract_dims=right_contract_dims) - - return square_grad(_dot, src, weight, weight_sq) - - -@prenorm -@with_context() -def mix(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: - weight_shape = [ctx.dims.spatial_mixing_kernel] * 2 - run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) - wgt0, wgt0_sq = get_param(ctx, "mix_0", weight_shape, return_sq=True) - wgt1, wgt1_sq = get_param(ctx, "mix_1", weight_shape, return_sq=True) - scale, scale_sq = get_param(ctx, "scale", [ctx.dims.features], std=0, mean=1, dtype=run_type, return_sq=True) - if ctx.is_initializing: - return inp - - original_shape = inp.shape - _batch, sequence, _features = original_shape - max_dims = math.ceil(math.log(sequence, ctx.dims.spatial_mixing_kernel)) - original_batch = inp.shape[0] - if ctx.model.autoregressive: - wgt0 = jnp.triu(wgt0) - wgt1 = jnp.triu(wgt1) - - def _get_mix_fn(current_depth: int): - def _fn(x: jax.Array): - batch = max(sequence // ctx.dims.spatial_mixing_kernel ** (current_depth % max_dims + 1), 1) - out = x.reshape(original_batch * batch, ctx.dims.spatial_mixing_kernel, -1) - inner_batch, inner_sequence, inner_features = out.shape - - # Shape[Batch, Sequence, 
Features] * Shape[Sequence, Sequence] -> Shape[Batch, Features, Sequence] - out = dot_sq(out, wgt0, wgt0_sq, left_contract_dims=(1,), right_contract_dims=(0,)) - - out = out.reshape(-1, ctx.dims.features, inner_sequence) - out = scale_norm_act(ctx, out, ctx.dims.features, weight=(scale, scale_sq), add_to_prefix=False, dim=1) - out = out.reshape(inner_batch, inner_features, inner_sequence) - - # Shape[Batch, Features, Sequence] * Shape[Sequence, Sequence] -> Shape[Batch, Features, Sequence] - out = dot_sq(out, wgt1, wgt1_sq, left_contract_dims=(2,), right_contract_dims=(0,)) - out = out.transpose(0, 2, 1) - return out.reshape(original_shape) - - return _fn - - return pattern_match(_get_mix_fn, max_dims, depth, inp) diff --git a/src/model/norm.py b/src/model/norm.py index f748800f..624e2db7 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -4,7 +4,6 @@ from jax import lax, numpy as jnp from src.backend import get_param, promote_to, stable_rsqrt, with_context -from src.constants import ParallelAxes from src.context import Context from src.model.activate import activate_forward, activate_grad @@ -19,41 +18,29 @@ def _fn(ctx: Context, inp: jax.Array, *args) -> jax.Array: return _fn -def all_gather(inp: jax.Array, dim: int) -> jax.Array: - @jax.custom_gradient - def _fn(x): - def _grad(dy): - return lax.psum_scatter(dy, axis_name=ParallelAxes.model, scatter_dimension=dim, tiled=True) - - return lax.all_gather(x, axis_name=ParallelAxes.model, axis=dim, tiled=True), _grad - - return _fn(inp) - - -def norm_forward(ctx: Context, src: jax.Array, wgt: Optional[jax.Array] = None, psum: bool = False, - act: bool = True, dim: int = 2): +def norm_forward(ctx: Context, src: jax.Array, wgt: Optional[jax.Array] = None, act: bool = True, dim: int = 2, + double: bool = False): run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) src_fp64 = promote_to(src, run_type) own_sum = lax.square(src_fp64).sum(dim, keepdims=True) - if psum: - own_sum = lax.psum(own_sum, ParallelAxes.model) std = stable_rsqrt(own_sum, ctx.model.norm.eps) out = src_fp64 * std * wgt - if act: + if act and double: + out = jnp.concatenate([activate_forward(out), activate_forward(-out)], dim) + elif act: out = activate_forward(out) out = out.astype(src.dtype) - if psum: - out = all_gather(out, dim) return out, std @with_context() def scale_norm_act(ctx: Context, inp: jax.Array, feature_dim: int, - weight: Union[bool, None, Tuple[jax.Array, jax.Array]] = None, - psum: bool = False, act: bool = True, dim: int = 2) -> jax.Array: + weight: Union[bool, None, Tuple[jax.Array, jax.Array]] = None, act: bool = True, dim: int = 2, + double: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) if weight is None: - weight, weight_sq = get_param(ctx, "scale", [feature_dim], std=0, mean=1, dtype=run_type, return_sq=True) + weight, weight_sq = get_param(ctx, "scale", [feature_dim * (1 + double)], std=0, mean=1, dtype=run_type, + return_sq=True) elif weight is False: weight_sq = weight = 1 else: @@ -67,23 +54,25 @@ def _fn(src: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): if isinstance(wgt, jax.Array): wgt = wgt.reshape((1,) * dim + (-1,) + (1,) * (src.ndim - 1 - dim)) - out, std = norm_forward(ctx, src, wgt, psum, act, dim) + out, std = norm_forward(ctx, src, wgt, act, dim, double) def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, None, None]]: - inner_src = lax.all_gather(src, ParallelAxes.model, axis=dim) if psum else src + inner_src = src 
src_fp64 = promote_to(inner_src, run_type) norm_out = src_fp64 * std dy = promote_to(dy, run_type) if act: - dy = dy * activate_grad(norm_out * wgt) + bw_out = norm_out * wgt + if double: + dy = dy * activate_grad(bw_out) - dy * activate_grad(-bw_out) + else: + dy *= activate_grad(bw_out) d_normed = dy * wgt d_std = (d_normed * src_fp64).sum(dim, keepdims=True) # broadcast forward -> sum backward d_std *= std ** 3 # reciprocal + x^(1/pow) -> 1/std^2 * 1/std^(pow-1) * 1/pow d_std *= src_fp64 # x^pow -> pow * x^(pow-1), multiply fused with above dx = d_normed * std - d_std - if psum: - dx = lax.psum_scatter(dx, axis_name=ParallelAxes.model, scatter_dimension=dim, tiled=True) dx = dx.astype(src.dtype) if not isinstance(wgt, jax.Array): diff --git a/unittests/consistency/step.py b/unittests/consistency/step.py index d561c54d..6f3ff73e 100644 --- a/unittests/consistency/step.py +++ b/unittests/consistency/step.py @@ -15,10 +15,8 @@ def get_wctx(config: typing.Optional[typing.Dict[str, typing.Any]] = None): ctx = wctx.ctx ctx.dims.batch = 16 - ctx.dims.spatial_mixing_kernel = 8 ctx.dims.sequence = 128 ctx.dims.features = 16 - ctx.dims.pointwise_features = 32 ctx.dims.inner_bottleneck_features = 8 return wctx, ctx diff --git a/unittests/grad/leak.py b/unittests/grad/leak.py index 0f2568a3..8b9d5b9a 100644 --- a/unittests/grad/leak.py +++ b/unittests/grad/leak.py @@ -42,8 +42,6 @@ def test(samples: int, depth: int): ctx.dims.depth = depth ctx.dims.features = 8 ctx.dims.inner_bottleneck_features = 4 - ctx.dims.pointwise_features = 16 - ctx.dims.spatial_mixing_kernel = ctx.dims.sequence // 2 src = randn(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features).astype(jnp.bfloat16) def _fn(x: jax.Array): diff --git a/unittests/grad/norm.py b/unittests/grad/norm.py index 9a6c39a4..c4303628 100644 --- a/unittests/grad/norm.py +++ b/unittests/grad/norm.py @@ -1,4 +1,3 @@ -import jax import pytest from jax import numpy as jnp @@ -7,15 +6,15 @@ from unittests.grad.backend import grad_fn, randn_fn, sample_sizes, trials -def general_test(act: bool, psum: bool, samples: int, dim: int): # skipcq: PYL-W0640 +def general_test(act: bool, samples: int, dim: int, double: bool): # skipcq: PYL-W0640 ctx = Context() ctx.is_initializing = False randn = randn_fn() for trial in range(trials): src = randn(int(samples ** 0.5), int(samples ** 0.5), ctx.dims.features) - multiplier = jax.device_count() if psum else 1 out_shape = list(src.shape)[1:] - out_shape[dim] *= multiplier + if double: + out_shape[dim] *= 2 wgt = randn(out_shape[dim]) wgt_sq = randn(out_shape[dim]) dy = randn(*out_shape) @@ -24,8 +23,8 @@ def general_test(act: bool, psum: bool, samples: int, dim: int): # skipcq: PYL- print(trial) shape = (1,) * dim + (-1,) + (1,) * (src.ndim - 2 - dim) - out0 = grad(lambda x: norm_forward(ctx, x[0], x[1].reshape(shape), bool(psum), act, dim)[0]) - out1 = grad(lambda x: scale_norm_act(ctx, x[0], ctx.dims.features, (x[1], x[2]), bool(psum), act, dim)) + out0 = grad(lambda x: norm_forward(ctx, x[0], x[1].reshape(shape), act, dim, double)[0]) + out1 = grad(lambda x: scale_norm_act(ctx, x[0], ctx.dims.features, (x[1], x[2]), act, dim, double)) assert jnp.allclose(out0[0], out1[0]) assert jnp.allclose(out0[1], out1[1]) @@ -34,16 +33,16 @@ def general_test(act: bool, psum: bool, samples: int, dim: int): # skipcq: PYL- @pytest.mark.parametrize("act", [True, False]) @pytest.mark.parametrize("samples", sample_sizes) def test_act(act: bool, samples: int): - general_test(act, False, samples, 2) + general_test(act, samples, 2, False) 
-@pytest.mark.parametrize("psum", [False, True]) +@pytest.mark.parametrize("dim", [0, 1, 2]) @pytest.mark.parametrize("samples", sample_sizes) -def test_psum(psum: bool, samples: int): - general_test(True, psum, samples, 2) +def test_dim(dim: int, samples: int): + general_test(True, samples, dim, False) -@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("double", [False, True]) @pytest.mark.parametrize("samples", sample_sizes) -def test_dim(dim: int, samples: int): - general_test(True, False, samples, dim) +def test_double(double: bool, samples: int): + general_test(True, samples, 2, double) From ae6b483fb6f54a515891dcc77f0652f2cd6a6c7c Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 22 Jan 2023 16:25:21 +0100 Subject: [PATCH 03/24] Revert "test(optimizer): no truegrad" This reverts commit 00c269c51b3b4299ab60d72e6aa6deacb5eb5479. --- src/model/main.py | 3 ++- src/optimizer.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/model/main.py b/src/model/main.py index 1497e7c9..782d2a76 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -63,7 +63,8 @@ def stem(ctx: Context, src: FourArrays) -> FourArrays: return src -def body_ctx(ctx: Context, src: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: +def body_ctx(ctx: Context, src: jax.Array) -> Union[ + Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: src = input_embed(ctx, src) zero = jnp.zeros_like(src) src = stem(ctx, (src, zero, src, zero)) diff --git a/src/optimizer.py b/src/optimizer.py index 2b0c2dfc..52686fbd 100644 --- a/src/optimizer.py +++ b/src/optimizer.py @@ -66,11 +66,14 @@ def graft(param_name: str, magnitude: jax.Array, direction: jax.Array) -> jax.Ar def tg_adam(ctx: Context, param_name: str, grad: jax.Array, tg_grad: jax.Array, step: jax.Array) -> jax.Array: ema_g = ema(ctx, grad, step, 1 - ctx.optimizer.adam_beta1) ema_gsq = ema(ctx, grad ** 2, step, 1 - ctx.optimizer.adam_beta2) + ema_tgsq = ema(ctx, tg_grad, step, 1 - ctx.optimizer.adam_beta3) if ctx.is_initializing: return grad - return ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon) + adam_update = ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon) + tg_update = ema_g * stable_rsqrt(ema_tgsq, ctx.optimizer.epsilon) + return graft(param_name, adam_update, tg_update) def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array: From bf68bf6c774ad5a703be470f6ad994cab92ded2a Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 22 Jan 2023 17:23:17 +0100 Subject: [PATCH 04/24] fix(conv): remove leak --- src/model/conv.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/model/conv.py b/src/model/conv.py index 2d17b90f..ea65a41c 100644 --- a/src/model/conv.py +++ b/src/model/conv.py @@ -37,17 +37,19 @@ def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: def _get_mix_fn(current_depth: int): outer_sequence = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) + inner_sequence = sequence // outer_sequence # == dilation + pad_len = (features - 1) * inner_sequence def _fn(x: jax.Array): - out = x.reshape(original_batch, outer_sequence, features, -1) + pad = jnp.zeros((original_batch, pad_len), x.dtype) + x = jnp.concatenate([pad, x.reshape(original_batch, -1)[:, :-pad_len]], 1) + out = x.reshape(original_batch, outer_sequence, features, inner_sequence) out = jnp.transpose(out, (0, 1, 3, 2)) return out.reshape(original_batch, 
sequence, features) return _fn - pad = jnp.zeros((original_batch, ctx.dims.features - 1, features), inp.dtype) - inp_padded = jnp.concatenate([pad, inp[:, :1 - ctx.dims.features]], 1) - inp = jnp.concatenate([inp, pattern_match(_get_mix_fn, max_dims, depth, inp_padded)], -1) + inp = jnp.concatenate([inp, pattern_match(_get_mix_fn, max_dims, depth, inp)], -1) inp = conv(ctx, inp, 5, 2 * ctx.dims.features, 2 * ctx.dims.features) inp = scale_norm_act(ctx, inp, 2 * ctx.dims.features, double=True) From 99f2d254cdda606dc200e872e81cc18f176f2851 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 22 Jan 2023 17:23:37 +0100 Subject: [PATCH 05/24] fix: lots of minor issue --- src/model/activate.py | 10 +++++++--- src/model/main.py | 4 ++-- src/model/norm.py | 6 +++--- unittests/grad/leak.py | 2 ++ unittests/grad/norm.py | 2 -- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/model/activate.py b/src/model/activate.py index 0a99633c..d70ffde0 100644 --- a/src/model/activate.py +++ b/src/model/activate.py @@ -5,13 +5,17 @@ # [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992) showed that # "complicated" activation functions such as SiLU, GeLU and Mish help increase the rank of the embedding matrix. # To ensure low-redundancy features, we use a function from that family. +def _tanh_sp(inp: jax.Array) -> jax.Array: # jax.nn.softplus without nan check + return lax.tanh(lax.max(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp)))) + + def activate_forward(inp: jax.Array) -> jax.Array: - return inp * lax.tanh(lax.max(inp, 0) + lax.log1p(lax.exp(-lax.abs(inp)))) # jax.nn.softplus without nan check + return inp * _tanh_sp(inp) def activate_grad(inp: jax.Array) -> jax.Array: - tanh_sp = lax.tanh(lax.max(inp, 0) + lax.log1p(lax.exp(-lax.abs(inp)))) - return tanh_sp + inp * jax.nn.sigmoid() * (1 - lax.square(tanh_sp)) + tanh_sp = _tanh_sp(inp) + return tanh_sp + inp * jax.nn.sigmoid(inp) * (1 - lax.square(tanh_sp)) def activate(inp: jax.Array) -> jax.Array: diff --git a/src/model/main.py b/src/model/main.py index 782d2a76..b830dfb4 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -36,9 +36,9 @@ def _fn(carry: FourArrays, inp: Tuple[Dict[str, jax.Array], jax.Array]): ctx.parameters.update(shared_params) depth = depth.reshape([]) src = [ctx.parameters] + list(carry) - src = reversible(ctx, dense_block, src) + src = reversible(ctx, dense_block, src, depth) src = reversible(ctx, dense_moe, src) - src = reversible(ctx, dense_block, src) + src = reversible(ctx, dense_block, src, depth) name_cache.update(ctx.name_cache) if ctx.is_initializing: return src diff --git a/src/model/norm.py b/src/model/norm.py index 624e2db7..e2cd99cd 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -39,8 +39,7 @@ def scale_norm_act(ctx: Context, inp: jax.Array, feature_dim: int, double: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) if weight is None: - weight, weight_sq = get_param(ctx, "scale", [feature_dim * (1 + double)], std=0, mean=1, dtype=run_type, - return_sq=True) + weight, weight_sq = get_param(ctx, "scale", [feature_dim], std=0, mean=1, dtype=run_type, return_sq=True) elif weight is False: weight_sq = weight = 1 else: @@ -64,7 +63,8 @@ def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[ if act: bw_out = norm_out * wgt if double: - dy = dy * activate_grad(bw_out) - dy * activate_grad(-bw_out) + dy0, dy1 = jnp.split(dy, 2, dim) + dy = 
dy0 * activate_grad(bw_out) - dy1 * activate_grad(-bw_out) else: dy *= activate_grad(bw_out) d_normed = dy * wgt diff --git a/unittests/grad/leak.py b/unittests/grad/leak.py index 8b9d5b9a..23b970c8 100644 --- a/unittests/grad/leak.py +++ b/unittests/grad/leak.py @@ -7,6 +7,7 @@ from src.constants import ParallelAxes from src.context import Context +from src.main import add_zeros from src.model.main import stem from src.model.reversible import revnet_out from unittests.grad.backend import grad_fn, randn_fn, trials @@ -51,6 +52,7 @@ def _fn(x: jax.Array): return params params = jax.pmap(_fn, ParallelAxes.model)(src) + add_zeros(params) ctx.is_initializing = False def _inner(inp: typing.Tuple[typing.Dict[str, jax.Array], jax.Array]): diff --git a/unittests/grad/norm.py b/unittests/grad/norm.py index c4303628..f7056d00 100644 --- a/unittests/grad/norm.py +++ b/unittests/grad/norm.py @@ -13,8 +13,6 @@ def general_test(act: bool, samples: int, dim: int, double: bool): # skipcq: PY for trial in range(trials): src = randn(int(samples ** 0.5), int(samples ** 0.5), ctx.dims.features) out_shape = list(src.shape)[1:] - if double: - out_shape[dim] *= 2 wgt = randn(out_shape[dim]) wgt_sq = randn(out_shape[dim]) dy = randn(*out_shape) From e28bd98e9ddd14abb35961578301b7f150ce63d8 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 22 Jan 2023 18:07:13 +0100 Subject: [PATCH 06/24] feat(tests): add test script --- src/model/norm.py | 8 ++++---- test.sh | 9 +++++++++ unittests/grad/backend.py | 13 +++++++------ unittests/grad/norm.py | 2 ++ 4 files changed, 22 insertions(+), 10 deletions(-) create mode 100644 test.sh diff --git a/src/model/norm.py b/src/model/norm.py index e2cd99cd..39d11aeb 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -50,14 +50,14 @@ def scale_norm_act(ctx: Context, inp: jax.Array, feature_dim: int, @jax.custom_gradient def _fn(src: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): + weight_shape = wgt.shape if isinstance(wgt, jax.Array): wgt = wgt.reshape((1,) * dim + (-1,) + (1,) * (src.ndim - 1 - dim)) out, std = norm_forward(ctx, src, wgt, act, dim, double) def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, None, None]]: - inner_src = src - src_fp64 = promote_to(inner_src, run_type) + src_fp64 = promote_to(src, run_type) norm_out = src_fp64 * std dy = promote_to(dy, run_type) if act: @@ -81,8 +81,8 @@ def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[ summed = list(range(src.ndim)) del summed[dim] d_wgt = dy * norm_out - d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape((-1,)).astype(run_type) - d_wgt = d_wgt.sum(summed).reshape((-1,)).astype(run_type) + d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape(weight_shape).astype(run_type) + d_wgt = d_wgt.sum(summed).reshape(weight_shape).astype(run_type) return dx, d_wgt, d_wgt_sq return out, _grad diff --git a/test.sh b/test.sh new file mode 100644 index 00000000..2dd13b96 --- /dev/null +++ b/test.sh @@ -0,0 +1,9 @@ +export JAX_ENABLE_X64=1 # allow fp64 +export JAX_DEFAULT_DTYPE_BITS=64 # ..and enforce it + +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla.proto +export XLA_FLAGS="--xla_force_host_platform_device_count=8" # We don't use TPU-CPU for ML +# export XLA_FLAGS="--xla_step_marker_location=1 $XLA_FLAGS" # 0 = entry; 1 = outer while +export PYTHONPATH=`pwd` + +/usr/bin/env python3 -m pytest "$@" # for example: `bash test.sh 
unittests/grad/norm.py` diff --git a/unittests/grad/backend.py b/unittests/grad/backend.py index ae1b2759..823d601e 100644 --- a/unittests/grad/backend.py +++ b/unittests/grad/backend.py @@ -14,18 +14,19 @@ def randn_fn(): def _fn(*shape: int): seed = rng.randint(0, 2 ** 30) + + def _gen(x): + return jax.random.normal(jax.random.PRNGKey(x + seed), shape, jnp.float32).astype(jnp.float64) / div + div = (shape[-1] * jax.device_count()) ** 0.25 - fn = jax.pmap( - lambda x: jax.random.normal(jax.random.PRNGKey(x + seed), shape, jnp.float32).astype(jnp.float_) / div) - local_devices = jax.local_device_count() - seeds = jnp.arange(local_devices * jax.process_index(), local_devices * (1 + jax.process_index())) - return fn(seeds) + devices = jax.local_device_count() + return jax.pmap(_gen)(jnp.arange(devices * jax.process_index(), devices * (1 + jax.process_index()))) return _fn def grad_fn(dy: jax.Array, *args): def _fn(fn): - return jax.pmap(jax.grad(lambda x: (fn(x) * dy).sum()), ParallelAxes.model)(args) + return jax.pmap(jax.grad(lambda x, y: (fn(x) * y).sum()), ParallelAxes.model)(args, dy) return _fn diff --git a/unittests/grad/norm.py b/unittests/grad/norm.py index f7056d00..d9d56be4 100644 --- a/unittests/grad/norm.py +++ b/unittests/grad/norm.py @@ -15,6 +15,8 @@ def general_test(act: bool, samples: int, dim: int, double: bool): # skipcq: PY out_shape = list(src.shape)[1:] wgt = randn(out_shape[dim]) wgt_sq = randn(out_shape[dim]) + if double: + out_shape[dim] *= 2 dy = randn(*out_shape) print(dy.shape, src.shape, wgt.shape) grad = grad_fn(dy, src, wgt, wgt_sq) From f5dbac3720a592df963eb48b7a9c3e7e4b589adb Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 22 Jan 2023 18:10:54 +0100 Subject: [PATCH 07/24] style(main): fix type hint indentation --- src/model/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/model/main.py b/src/model/main.py index b830dfb4..12f805ba 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -63,8 +63,7 @@ def stem(ctx: Context, src: FourArrays) -> FourArrays: return src -def body_ctx(ctx: Context, src: jax.Array) -> Union[ - Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: +def body_ctx(ctx: Context, src: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: src = input_embed(ctx, src) zero = jnp.zeros_like(src) src = stem(ctx, (src, zero, src, zero)) From 6e7130eebb71de386cd8501dc0ccfd43fae1eee0 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Tue, 24 Jan 2023 18:02:20 +0100 Subject: [PATCH 08/24] perf(dense): use X+GLU+Mixer+Shift as input to dense; add more efficient scale_norm_act_conv --- src/backend.py | 5 +- src/model/conv.py | 50 +++++--------------- src/model/dense.py | 44 +++++++++++++++++ src/model/main.py | 2 +- src/model/moe.py | 13 ++---- src/model/norm.py | 114 +++++++++++++++++++++++++++++++-------------- 6 files changed, 143 insertions(+), 85 deletions(-) create mode 100644 src/model/dense.py diff --git a/src/backend.py b/src/backend.py index 8e8a361e..5c349e38 100644 --- a/src/backend.py +++ b/src/backend.py @@ -209,8 +209,9 @@ def loop(fn: Callable, fn_input: Any, steps: int, unroll: int = 1): typevar = TypeVar("typevar") +output = TypeVar("output") -def pattern_match(gen_fn: Callable[[int], Callable[[typevar], jax.Array]], cases: int, - predicate: jax.Array, base: typevar): +def pattern_match(gen_fn: Callable[[int], Callable[[typevar], output]], cases: int, + predicate: jax.Array, base: typevar) -> 
output: return lax.switch(predicate.astype(jnp.int32) % cases, [gen_fn(i) for i in range(cases)], base) diff --git a/src/model/conv.py b/src/model/conv.py index ea65a41c..9c5bf88c 100644 --- a/src/model/conv.py +++ b/src/model/conv.py @@ -1,21 +1,26 @@ -import math +from typing import Tuple import jax from jax import numpy as jnp -from src.backend import get_param, square_grad, with_context, conv as lax_conv, pattern_match +from src.backend import get_param, square_grad, with_context, conv as lax_conv from src.context import Context -from src.model.norm import prenorm, scale_norm_act -@with_context() -def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False): +@with_context +def conv_weight(ctx: Context, conv_kernel: int, in_features: int, out_features: int, tied: bool = False) -> Tuple[ + jax.Array, jax.Array]: fan_in = jnp.arange(conv_kernel, 0, -1, dtype=ctx.model.storage_dtype) fan_in = (1 - 1 / (conv_kernel * ctx.model.conv_scale + ctx.model.conv_shift)) ** fan_in fan_in = fan_in / fan_in.sum() fan_in = fan_in.reshape(1, 1, -1) - weight, weight_sq = get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2, - lr_scale=fan_in, tied=tied, return_sq=True) + return get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2, lr_scale=fan_in, + tied=tied, return_sq=True) + + +@with_context() +def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False): + weight, weight_sq = conv_weight(ctx, conv_kernel, in_features, out_features, tied) if ctx.is_initializing: return jnp.zeros(inp.shape[:-1] + (out_features,), dtype=inp.dtype) @@ -23,34 +28,3 @@ def _conv(x, y): return lax_conv(x, y, [(conv_kernel - 1, 0)], 1) return square_grad(_conv, inp, weight, weight_sq) - - -@prenorm -@with_context() -def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: - # Following [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992), we're - # not increasing the dimensionality in the middle, as the rank doesn't increase -> no useful features are added. 
- - original_shape = inp.shape - original_batch, sequence, features = original_shape - max_dims = math.ceil(math.log(sequence, ctx.dims.features)) - - def _get_mix_fn(current_depth: int): - outer_sequence = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) - inner_sequence = sequence // outer_sequence # == dilation - pad_len = (features - 1) * inner_sequence - - def _fn(x: jax.Array): - pad = jnp.zeros((original_batch, pad_len), x.dtype) - x = jnp.concatenate([pad, x.reshape(original_batch, -1)[:, :-pad_len]], 1) - out = x.reshape(original_batch, outer_sequence, features, inner_sequence) - out = jnp.transpose(out, (0, 1, 3, 2)) - return out.reshape(original_batch, sequence, features) - - return _fn - - inp = jnp.concatenate([inp, pattern_match(_get_mix_fn, max_dims, depth, inp)], -1) - - inp = conv(ctx, inp, 5, 2 * ctx.dims.features, 2 * ctx.dims.features) - inp = scale_norm_act(ctx, inp, 2 * ctx.dims.features, double=True) - return conv(ctx, inp, 5, 4 * ctx.dims.features, ctx.dims.features) diff --git a/src/model/dense.py b/src/model/dense.py new file mode 100644 index 00000000..2cc5f6f8 --- /dev/null +++ b/src/model/dense.py @@ -0,0 +1,44 @@ +import math + +import jax +from jax import numpy as jnp, lax + +from src.backend import with_context, pattern_match +from src.context import Context +from src.model.conv import conv +from src.model.norm import prenorm, scale_norm_act_conv + + +@prenorm +@with_context() +def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: + # Following [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992), we're + # not increasing the dimensionality in the middle, as the rank doesn't increase -> no useful features are added. + + original_shape = inp.shape + original_batch, sequence, features = original_shape + max_dims = math.ceil(math.log(sequence, ctx.dims.features)) + + arange = jnp.arange(features) + mask = arange.reshape(1, -1, 1, 1) >= arange.reshape(1, 1, 1, -1) + + def _get_mix_fn(current_depth: int): + def _fn(x: jax.Array): + outer_seq = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) + inner_seq = sequence // outer_seq # == dilation + inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq // features, features), (0, 1, 2, 3)) + + out = x.reshape(original_batch, outer_seq, features, inner_seq) + out = jnp.transpose(out, (0, 1, 3, 2)) + out = out.reshape(original_batch, sequence, features) + padded = lax.pad(out[:, :-features * inner_seq], 0., ((0, 0, 0), (features * inner_seq, 0, 0), (0, 0, 0))) + return out * inner.reshape(1, sequence, features), padded + + return _fn + + masked, padded = pattern_match(_get_mix_fn, max_dims, depth, inp) + inp_glu = inp * lax.pad(inp[:, :-1], 1., ((0, 0, 0), (1, 0, 0), (0, 0, 0))) + inp = jnp.concatenate([inp, inp_glu, masked, padded], -1) + + inp = conv(ctx, inp, 5, 4 * ctx.dims.features, 4 * ctx.dims.features) + return scale_norm_act_conv(ctx, inp, 5, 4 * ctx.dims.features, ctx.dims.features) diff --git a/src/model/main.py b/src/model/main.py index 12f805ba..a2d51b47 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -5,7 +5,7 @@ from src.backend import get_param, is_model, is_stacked, square_grad, with_context from src.context import Context -from src.model.conv import dense_block +from src.model.dense import dense_block from src.model.loss import cross_entropy_loss from src.model.moe import dense_moe from src.model.norm import scale_norm_act diff --git a/src/model/moe.py b/src/model/moe.py 
index d70ab46b..9073b90d 100644 --- a/src/model/moe.py +++ b/src/model/moe.py @@ -4,8 +4,7 @@ from src.backend import with_context from src.constants import ParallelAxes from src.context import Context -from src.model.conv import conv -from src.model.norm import prenorm, scale_norm_act +from src.model.norm import scale_norm_act, scale_norm_act_conv def all_to_all(ctx: Context, x: jax.Array, split_axis: int, concat_axis: int) -> jax.Array: @@ -22,7 +21,6 @@ def _grad(dy: jax.Array) -> jax.Array: return _fn(x) -@prenorm @with_context() def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: devices = jax.device_count() @@ -30,7 +28,7 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: batch, sequence, features = inp.shape sequence_slice = sequence // devices - inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features) + inp = scale_norm_act_conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features) # [Batch, Sequence, Features] -> [Batch, SequenceSlice, Features * Devices] # In essence, 1) Collect features from all devices + 2) Drop unused sequence elements @@ -40,9 +38,7 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: inp = inp.reshape(batch, sequence_slice, big_params) # Devices^2 more parameters than normal bottleneck block but only Devices-times more flops due to sparsity above - inp = scale_norm_act(ctx, inp, big_params) - inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True) - inp = scale_norm_act(ctx, inp, big_params) + inp = scale_norm_act_conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True) # [Batch, SequenceSlice, Features * Devices] -> [Batch, Sequence, Features] (PixelShuffle across devices) if not ctx.is_initializing: @@ -50,4 +46,5 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: inp = all_to_all(ctx, inp, 3, 2) inp = inp.reshape(batch, sequence, ctx.dims.inner_bottleneck_features) - return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features) + out = scale_norm_act_conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features) + return scale_norm_act(ctx, out, features, act=False) diff --git a/src/model/norm.py b/src/model/norm.py index 39d11aeb..3e28b046 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -1,11 +1,12 @@ -from typing import Tuple, Optional, Union, Callable +from typing import Tuple, Optional, Union, Callable, List import jax from jax import lax, numpy as jnp -from src.backend import get_param, promote_to, stable_rsqrt, with_context +from src.backend import get_param, promote_to, stable_rsqrt, with_context, conv as lax_conv from src.context import Context from src.model.activate import activate_forward, activate_grad +from src.model.conv import conv_weight def prenorm(fn: Callable[[Context, jax.Array], jax.Array]): @@ -19,18 +20,53 @@ def _fn(ctx: Context, inp: jax.Array, *args) -> jax.Array: def norm_forward(ctx: Context, src: jax.Array, wgt: Optional[jax.Array] = None, act: bool = True, dim: int = 2, - double: bool = False): + double: bool = False, std: Optional[jax.Array] = None): run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) src_fp64 = promote_to(src, run_type) - own_sum = lax.square(src_fp64).sum(dim, keepdims=True) - std = stable_rsqrt(own_sum, ctx.model.norm.eps) - out = src_fp64 * std * wgt + if std is None: + own_sum = lax.square(src_fp64).sum(dim, keepdims=True) + std = 
stable_rsqrt(own_sum, ctx.model.norm.eps) + norm_out = src_fp64 * std + out = multiplied = norm_out * wgt if act and double: out = jnp.concatenate([activate_forward(out), activate_forward(-out)], dim) elif act: out = activate_forward(out) out = out.astype(src.dtype) - return out, std + return out, norm_out, multiplied, src_fp64, std + + +def norm_backward(ctx: Context, src: jax.Array, wgt: jax.Array, std: jax.Array, dy: jax.Array, act: bool, + dim: int, double: bool, weight_shape: List[int], run_type: jnp.dtype, + src_fp64: Optional[jax.Array] = None, norm_out: Optional[jax.Array] = None, + bw_out: Optional[jax.Array] = None): + src_fp64 = promote_to(src, run_type) if src_fp64 is None else src_fp64 + norm_out = (src_fp64 * std) if norm_out is None else norm_out + dy = promote_to(dy, run_type) + if act: + bw_out = (norm_out * wgt) if bw_out is None else bw_out + if double: + dy0, dy1 = jnp.split(dy, 2, dim) + dy = dy0 * activate_grad(bw_out) - dy1 * activate_grad(-bw_out) + else: + dy *= activate_grad(bw_out) + d_normed = dy * wgt + + d_std = (d_normed * src_fp64).sum(dim, keepdims=True) # broadcast forward -> sum backward + d_std *= std ** 3 # reciprocal + x^(1/pow) -> 1/std^2 * 1/std^(pow-1) * 1/pow + d_std *= src_fp64 # x^pow -> pow * x^(pow-1), multiply fused with above + dx = d_normed * std - d_std + dx = dx.astype(src.dtype) + + if not isinstance(wgt, jax.Array): + return dx, None, None + + summed = list(range(src.ndim)) + del summed[dim] + d_wgt = dy * norm_out + d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape(weight_shape).astype(run_type) + d_wgt = d_wgt.sum(summed).reshape(weight_shape).astype(run_type) + return dx, d_wgt, d_wgt_sq @with_context() @@ -50,41 +86,47 @@ def scale_norm_act(ctx: Context, inp: jax.Array, feature_dim: int, @jax.custom_gradient def _fn(src: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): - weight_shape = wgt.shape if isinstance(wgt, jax.Array): wgt = wgt.reshape((1,) * dim + (-1,) + (1,) * (src.ndim - 1 - dim)) - out, std = norm_forward(ctx, src, wgt, act, dim, double) + out, _, _, _, std = norm_forward(ctx, src, wgt, act, dim, double) def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, None, None]]: - src_fp64 = promote_to(src, run_type) - norm_out = src_fp64 * std - dy = promote_to(dy, run_type) - if act: - bw_out = norm_out * wgt - if double: - dy0, dy1 = jnp.split(dy, 2, dim) - dy = dy0 * activate_grad(bw_out) - dy1 * activate_grad(-bw_out) - else: - dy *= activate_grad(bw_out) - d_normed = dy * wgt - - d_std = (d_normed * src_fp64).sum(dim, keepdims=True) # broadcast forward -> sum backward - d_std *= std ** 3 # reciprocal + x^(1/pow) -> 1/std^2 * 1/std^(pow-1) * 1/pow - d_std *= src_fp64 # x^pow -> pow * x^(pow-1), multiply fused with above - dx = d_normed * std - d_std - dx = dx.astype(src.dtype) - - if not isinstance(wgt, jax.Array): - return dx, None, None - - summed = list(range(src.ndim)) - del summed[dim] - d_wgt = dy * norm_out - d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape(weight_shape).astype(run_type) - d_wgt = d_wgt.sum(summed).reshape(weight_shape).astype(run_type) - return dx, d_wgt, d_wgt_sq + return norm_backward(ctx, src, wgt, std, dy, act, dim, double, _wgt_dummy.shape, run_type) return out, _grad return _fn(inp, weight, weight_sq) + + +@with_context +def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: int, out_features: int, + tied: bool = False) -> jax.Array: + run_type = jnp.promote_types(ctx.model.computation_dtype, 
jnp.float32) + scale, scale_sq = get_param(ctx, "scale", [in_features], std=0, mean=1, dtype=run_type, return_sq=True) + weight, weight_sq = conv_weight(tied=tied) + + if ctx.is_initializing: + return inp + + dim = inp.ndim - 1 + + def _conv(x, y): + return lax_conv(x, y, [(kernel - 1, 0)], 1) + + @jax.custom_gradient + def _fn(src: jax.Array, scl: jax.Array, _scl_dummy: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): + scl = scl.reshape((1,) * dim + (-1,)) + out, _, _, _, std = norm_forward(ctx, src, scl, True, dim, False) + + def _grad(dy: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + out2, norm_out, bw_out, src_fp64, _ = norm_forward(ctx, src, scl, True, dim, False, std) + _, d_wgt_sq = jax.vjp(_conv, lax.square(out2), wgt)[1](lax.square(dy)) + dy, d_wgt = jax.vjp(_conv, out2, wgt)[1](dy) + dx, d_scl, d_scl_sq = norm_backward(ctx, src, scl, std, dy, True, dim, False, _wgt_dummy.shape, run_type, + src_fp64, norm_out, bw_out) + return dx, d_scl, d_scl_sq, d_wgt, d_wgt_sq + + return lax_conv(out, wgt, [(kernel - 1, 0)], 1), _grad + + return _fn(inp, scale, scale_sq, weight, weight_sq) From 0a1ea7c15a89b4bada2f269db17aaf5783d02462 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 25 Jan 2023 15:32:29 +0100 Subject: [PATCH 09/24] style: fix deepsource issues --- src/model/conv.py | 4 ++-- src/model/norm.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/model/conv.py b/src/model/conv.py index 9c5bf88c..22e71125 100644 --- a/src/model/conv.py +++ b/src/model/conv.py @@ -8,8 +8,8 @@ @with_context -def conv_weight(ctx: Context, conv_kernel: int, in_features: int, out_features: int, tied: bool = False) -> Tuple[ - jax.Array, jax.Array]: +def conv_weight(ctx: Context, conv_kernel: int, in_features: int, out_features: int, tied: bool = False + ) -> Tuple[jax.Array, jax.Array]: fan_in = jnp.arange(conv_kernel, 0, -1, dtype=ctx.model.storage_dtype) fan_in = (1 - 1 / (conv_kernel * ctx.model.conv_scale + ctx.model.conv_shift)) ** fan_in fan_in = fan_in / fan_in.sum() diff --git a/src/model/norm.py b/src/model/norm.py index 3e28b046..bae9616d 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -47,7 +47,7 @@ def norm_backward(ctx: Context, src: jax.Array, wgt: jax.Array, std: jax.Array, bw_out = (norm_out * wgt) if bw_out is None else bw_out if double: dy0, dy1 = jnp.split(dy, 2, dim) - dy = dy0 * activate_grad(bw_out) - dy1 * activate_grad(-bw_out) + dy = dy0 * activate_grad(bw_out) - dy1 * activate_grad(-bw_out) # skipcq: PYL-E1130 else: dy *= activate_grad(bw_out) d_normed = dy * wgt @@ -104,7 +104,7 @@ def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: tied: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) scale, scale_sq = get_param(ctx, "scale", [in_features], std=0, mean=1, dtype=run_type, return_sq=True) - weight, weight_sq = conv_weight(tied=tied) + weight, weight_sq = conv_weight(ctx, kernel, in_features, out_features, tied) if ctx.is_initializing: return inp From 504153fb61b21e908ff6799d27987d2a9851e0dc Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 10:28:54 +0100 Subject: [PATCH 10/24] fix(dense): use correct lax.pad dtype --- src/model/dense.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/model/dense.py b/src/model/dense.py index 2cc5f6f8..26d4bb2b 100644 --- a/src/model/dense.py +++ b/src/model/dense.py @@ 
-31,13 +31,14 @@ def _fn(x: jax.Array): out = x.reshape(original_batch, outer_seq, features, inner_seq) out = jnp.transpose(out, (0, 1, 3, 2)) out = out.reshape(original_batch, sequence, features) - padded = lax.pad(out[:, :-features * inner_seq], 0., ((0, 0, 0), (features * inner_seq, 0, 0), (0, 0, 0))) + padded = lax.pad(out[:, :-features * inner_seq], jnp.zeros((), dtype=inp.dtype), + ((0, 0, 0), (features * inner_seq, 0, 0), (0, 0, 0))) return out * inner.reshape(1, sequence, features), padded return _fn masked, padded = pattern_match(_get_mix_fn, max_dims, depth, inp) - inp_glu = inp * lax.pad(inp[:, :-1], 1., ((0, 0, 0), (1, 0, 0), (0, 0, 0))) + inp_glu = inp * lax.pad(inp[:, :-1], jnp.ones((), dtype=inp.dtype), ((0, 0, 0), (1, 0, 0), (0, 0, 0))) inp = jnp.concatenate([inp, inp_glu, masked, padded], -1) inp = conv(ctx, inp, 5, 4 * ctx.dims.features, 4 * ctx.dims.features) From a12d843f1da8f3ac75c7f8ffe34f70ba19551ffe Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 10:40:48 +0100 Subject: [PATCH 11/24] fix(dense): max path doesn't work with pad --- src/model/dense.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model/dense.py b/src/model/dense.py index 26d4bb2b..854197a8 100644 --- a/src/model/dense.py +++ b/src/model/dense.py @@ -17,14 +17,14 @@ def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: original_shape = inp.shape original_batch, sequence, features = original_shape - max_dims = math.ceil(math.log(sequence, ctx.dims.features)) + max_dims = math.floor(math.log(sequence, ctx.dims.features)) arange = jnp.arange(features) mask = arange.reshape(1, -1, 1, 1) >= arange.reshape(1, 1, 1, -1) def _get_mix_fn(current_depth: int): def _fn(x: jax.Array): - outer_seq = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) + outer_seq = sequence // ctx.dims.features ** (current_depth % max_dims + 1) inner_seq = sequence // outer_seq # == dilation inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq // features, features), (0, 1, 2, 3)) From 81b153a97097a25a445da0da842eded2dd47eb3c Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 10:47:01 +0100 Subject: [PATCH 12/24] fix(dense): *features was for numel, not idx in dim --- src/model/dense.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/model/dense.py b/src/model/dense.py index 854197a8..8b6c8bdc 100644 --- a/src/model/dense.py +++ b/src/model/dense.py @@ -17,22 +17,22 @@ def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: original_shape = inp.shape original_batch, sequence, features = original_shape - max_dims = math.floor(math.log(sequence, ctx.dims.features)) + max_dims = math.ceil(math.log(sequence, ctx.dims.features)) arange = jnp.arange(features) mask = arange.reshape(1, -1, 1, 1) >= arange.reshape(1, 1, 1, -1) def _get_mix_fn(current_depth: int): def _fn(x: jax.Array): - outer_seq = sequence // ctx.dims.features ** (current_depth % max_dims + 1) + outer_seq = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) inner_seq = sequence // outer_seq # == dilation inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq // features, features), (0, 1, 2, 3)) out = x.reshape(original_batch, outer_seq, features, inner_seq) out = jnp.transpose(out, (0, 1, 3, 2)) out = out.reshape(original_batch, sequence, features) - padded = lax.pad(out[:, :-features * 
inner_seq], jnp.zeros((), dtype=inp.dtype), - ((0, 0, 0), (features * inner_seq, 0, 0), (0, 0, 0))) + padded = lax.pad(out[:, :-inner_seq], jnp.zeros((), dtype=inp.dtype), + ((0, 0, 0), (inner_seq, 0, 0), (0, 0, 0))) return out * inner.reshape(1, sequence, features), padded return _fn From 7907a3dfe212b9f1f9c1b81a65d3b52dff5b1ae5 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 10:50:45 +0100 Subject: [PATCH 13/24] fix: add missing () to with_context --- src/model/conv.py | 2 +- src/model/norm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model/conv.py b/src/model/conv.py index 22e71125..983f74a8 100644 --- a/src/model/conv.py +++ b/src/model/conv.py @@ -7,7 +7,7 @@ from src.context import Context -@with_context +@with_context() def conv_weight(ctx: Context, conv_kernel: int, in_features: int, out_features: int, tied: bool = False ) -> Tuple[jax.Array, jax.Array]: fan_in = jnp.arange(conv_kernel, 0, -1, dtype=ctx.model.storage_dtype) diff --git a/src/model/norm.py b/src/model/norm.py index bae9616d..ce567009 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -99,7 +99,7 @@ def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[ return _fn(inp, weight, weight_sq) -@with_context +@with_context() def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: int, out_features: int, tied: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) From 9292d83850aaba330d4b215b17838c93986decf8 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 11:04:11 +0100 Subject: [PATCH 14/24] fix(dense): handle unreshapable mask case --- src/model/dense.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/model/dense.py b/src/model/dense.py index 8b6c8bdc..f7819634 100644 --- a/src/model/dense.py +++ b/src/model/dense.py @@ -26,14 +26,15 @@ def _get_mix_fn(current_depth: int): def _fn(x: jax.Array): outer_seq = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) inner_seq = sequence // outer_seq # == dilation - inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq // features, features), (0, 1, 2, 3)) + inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq, features), (0, 1, 2, 3)) + inner = inner.reshape(1, -1, features)[:, :sequence] out = x.reshape(original_batch, outer_seq, features, inner_seq) out = jnp.transpose(out, (0, 1, 3, 2)) out = out.reshape(original_batch, sequence, features) padded = lax.pad(out[:, :-inner_seq], jnp.zeros((), dtype=inp.dtype), ((0, 0, 0), (inner_seq, 0, 0), (0, 0, 0))) - return out * inner.reshape(1, sequence, features), padded + return out * inner, padded return _fn From 4b8eb08b93c985830b211001d206ad51ea912d79 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 11:10:58 +0100 Subject: [PATCH 15/24] fix(activate): casting issue --- src/model/activate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model/activate.py b/src/model/activate.py index d70ffde0..52b23e50 100644 --- a/src/model/activate.py +++ b/src/model/activate.py @@ -1,12 +1,12 @@ import jax -from jax import lax +from jax import lax, numpy as jnp # [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992) showed that # "complicated" activation functions such as SiLU, GeLU and Mish help 
increase the rank of the embedding matrix. # To ensure low-redundancy features, we use a function from that family. def _tanh_sp(inp: jax.Array) -> jax.Array: # jax.nn.softplus without nan check - return lax.tanh(lax.max(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp)))) + return lax.tanh(jnp.maximum(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp)))) def activate_forward(inp: jax.Array) -> jax.Array: From 99ddce37692183c1b4ed2e6d0883204f0f8f1d7d Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 11:19:36 +0100 Subject: [PATCH 16/24] fix(norm): normalize out_features --- src/model/norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/norm.py b/src/model/norm.py index ce567009..89b9791b 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -103,7 +103,7 @@ def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[ def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: int, out_features: int, tied: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) - scale, scale_sq = get_param(ctx, "scale", [in_features], std=0, mean=1, dtype=run_type, return_sq=True) + scale, scale_sq = get_param(ctx, "scale", [out_features], std=0, mean=1, dtype=run_type, return_sq=True) weight, weight_sq = conv_weight(ctx, kernel, in_features, out_features, tied) if ctx.is_initializing: From b6da2499fea7f9900df8b779ef4d3a24e4025e7e Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 11:29:16 +0100 Subject: [PATCH 17/24] fix(norm): revert --- src/model/norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/norm.py b/src/model/norm.py index 89b9791b..ce567009 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -103,7 +103,7 @@ def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[ def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: int, out_features: int, tied: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) - scale, scale_sq = get_param(ctx, "scale", [out_features], std=0, mean=1, dtype=run_type, return_sq=True) + scale, scale_sq = get_param(ctx, "scale", [in_features], std=0, mean=1, dtype=run_type, return_sq=True) weight, weight_sq = conv_weight(ctx, kernel, in_features, out_features, tied) if ctx.is_initializing: From ee000be9d138618073729a8074d9e36b9ae3040a Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 11:51:15 +0100 Subject: [PATCH 18/24] test(norm): show shapes --- src/model/norm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/model/norm.py b/src/model/norm.py index ce567009..6fc9579b 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -117,6 +117,7 @@ def _conv(x, y): @jax.custom_gradient def _fn(src: jax.Array, scl: jax.Array, _scl_dummy: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): scl = scl.reshape((1,) * dim + (-1,)) + print(src.shape, scale.shape, scl.shape, in_features, out_features) out, _, _, _, std = norm_forward(ctx, src, scl, True, dim, False) def _grad(dy: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: From 553fd428878be0db519614be534575f74e6184c5 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 12:01:06 +0100 Subject: [PATCH 19/24] test(norm): show name cache and 
prefix --- src/model/norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/norm.py b/src/model/norm.py index 6fc9579b..ac117243 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -117,7 +117,7 @@ def _conv(x, y): @jax.custom_gradient def _fn(src: jax.Array, scl: jax.Array, _scl_dummy: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): scl = scl.reshape((1,) * dim + (-1,)) - print(src.shape, scale.shape, scl.shape, in_features, out_features) + print(ctx.name_cache, ctx.global_prefix, src.shape, scale.shape, scl.shape, in_features, out_features) out, _, _, _, std = norm_forward(ctx, src, scl, True, dim, False) def _grad(dy: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: From 289a9f105e34932c9db6a1847944d3824aacb482 Mon Sep 17 00:00:00 2001 From: ClashLuke <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 12:19:10 +0100 Subject: [PATCH 20/24] fix(norm): change shape in norm conv --- src/model/norm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/model/norm.py b/src/model/norm.py index ac117243..a9f26bb7 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -107,7 +107,7 @@ def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: weight, weight_sq = conv_weight(ctx, kernel, in_features, out_features, tied) if ctx.is_initializing: - return inp + return jnp.zeros(inp.shape[:-1] + (out_features,), dtype=inp.dtype) dim = inp.ndim - 1 @@ -117,7 +117,6 @@ def _conv(x, y): @jax.custom_gradient def _fn(src: jax.Array, scl: jax.Array, _scl_dummy: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): scl = scl.reshape((1,) * dim + (-1,)) - print(ctx.name_cache, ctx.global_prefix, src.shape, scale.shape, scl.shape, in_features, out_features) out, _, _, _, std = norm_forward(ctx, src, scl, True, dim, False) def _grad(dy: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: From 2b728d2a979712274520fb20ba0379ee7025ccba Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 12:35:40 +0100 Subject: [PATCH 21/24] fix(norm): use correct shape --- src/model/norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model/norm.py b/src/model/norm.py index a9f26bb7..b636ecf9 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -123,7 +123,7 @@ def _grad(dy: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, ja out2, norm_out, bw_out, src_fp64, _ = norm_forward(ctx, src, scl, True, dim, False, std) _, d_wgt_sq = jax.vjp(_conv, lax.square(out2), wgt)[1](lax.square(dy)) dy, d_wgt = jax.vjp(_conv, out2, wgt)[1](dy) - dx, d_scl, d_scl_sq = norm_backward(ctx, src, scl, std, dy, True, dim, False, _wgt_dummy.shape, run_type, + dx, d_scl, d_scl_sq = norm_backward(ctx, src, scl, std, dy, True, dim, False, _scl_dummy.shape, run_type, src_fp64, norm_out, bw_out) return dx, d_scl, d_scl_sq, d_wgt, d_wgt_sq From ea36c4212494330b550f57d7d696ccc361e9820c Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 13:06:11 +0100 Subject: [PATCH 22/24] fix(norm): only get shape if array --- src/model/norm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model/norm.py b/src/model/norm.py index b636ecf9..f81a955e 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -92,7 +92,8 @@ def _fn(src: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): out, _, _, _, std = norm_forward(ctx, src, wgt, act, dim, 
double) def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, None, None]]: - return norm_backward(ctx, src, wgt, std, dy, act, dim, double, _wgt_dummy.shape, run_type) + shp = _wgt_dummy.shape if isinstance(weight, jax.Array) else () + return norm_backward(ctx, src, wgt, std, dy, act, dim, double, shp, run_type) return out, _grad From d57a657fd54379416a681a24eb4f72da4688bb22 Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Wed, 8 Feb 2023 17:23:37 +0100 Subject: [PATCH 23/24] v4 --- src/utils/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/checkpoint.py b/src/utils/checkpoint.py index 58e51033..36362b6d 100644 --- a/src/utils/checkpoint.py +++ b/src/utils/checkpoint.py @@ -24,7 +24,7 @@ UPLOAD_RETRIES = 8 WCTX_VALUES = ("scalars", "current_step") TMP_PATH_ADDON = "_____TEMPORARY" -GSUTIL_PATH = "/opt/google-cloud-sdk/bin/gsutil" +GSUTIL_PATH = "/snap/bin/gsutil" def log(arg: str, verbose: bool): From ba2e31151dad37832eedc3ded41eb67fc32808da Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sat, 20 Jan 2024 17:25:49 +0100 Subject: [PATCH 24/24] perf(dense): matmul during memory op --- src/model/dense.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/model/dense.py b/src/model/dense.py index f7819634..0f462791 100644 --- a/src/model/dense.py +++ b/src/model/dense.py @@ -40,7 +40,6 @@ def _fn(x: jax.Array): masked, padded = pattern_match(_get_mix_fn, max_dims, depth, inp) inp_glu = inp * lax.pad(inp[:, :-1], jnp.ones((), dtype=inp.dtype), ((0, 0, 0), (1, 0, 0), (0, 0, 0))) - inp = jnp.concatenate([inp, inp_glu, masked, padded], -1) - inp = conv(ctx, inp, 5, 4 * ctx.dims.features, 4 * ctx.dims.features) + inp = sum(conv(ctx, i, 5, ctx.dims.features, 4 * ctx.dims.features) for i in (inp, inp_glu, masked, padded)) return scale_norm_act_conv(ctx, inp, 5, 4 * ctx.dims.features, ctx.dims.features)
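
Note on the hand-written backward passes used throughout this series: `activate_grad` is the manually derived Mish derivative consumed by the `jax.custom_gradient` in `scale_norm_act` / `scale_norm_act_conv`, and the `double=True` path splits the incoming cotangent into `dy0 * act'(x) - dy1 * act'(-x)`. The sketch below is illustrative only (plain JAX, no `Context`, weights, or normalization — it restates just the activation rules from `src/model/activate.py` after PATCH 15 and the split rule from `src/model/norm.py`) and checks both against `jax.grad` / `jax.vjp`:

```python
# Standalone sanity check (not part of these patches): verifies the hand-derived
# Mish gradient and the double-activation backward split against JAX autodiff.
import jax
from jax import lax, numpy as jnp


def _tanh_sp(inp: jax.Array) -> jax.Array:  # jax.nn.softplus without the nan check
    return lax.tanh(jnp.maximum(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp))))


def activate_forward(inp: jax.Array) -> jax.Array:  # Mish: x * tanh(softplus(x))
    return inp * _tanh_sp(inp)


def activate_grad(inp: jax.Array) -> jax.Array:  # hand-derived d/dx Mish(x)
    tanh_sp = _tanh_sp(inp)
    return tanh_sp + inp * jax.nn.sigmoid(inp) * (1 - lax.square(tanh_sp))


def double_act(inp: jax.Array) -> jax.Array:
    # Forward shape of norm_forward(double=True), stripped of the norm and scale:
    # activate both +x and -x and concatenate along the feature dimension.
    return jnp.concatenate([activate_forward(inp), activate_forward(-inp)], axis=-1)


if __name__ == "__main__":
    # 1) activate_grad matches autodiff of activate_forward.
    x = jnp.linspace(-6., 6., 1001)
    assert jnp.allclose(jax.vmap(jax.grad(activate_forward))(x), activate_grad(x), atol=1e-5)

    # 2) The double=True backward rule: dx = dy0 * act'(x) - dy1 * act'(-x).
    x2 = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
    dy = jax.random.normal(jax.random.PRNGKey(1), (4, 16))
    _, vjp = jax.vjp(double_act, x2)
    dy0, dy1 = jnp.split(dy, 2, axis=-1)
    assert jnp.allclose(vjp(dy)[0], dy0 * activate_grad(x2) - dy1 * activate_grad(-x2), atol=1e-5)
    print("manual gradients match autodiff")
```

In the repository itself the split rule is applied to `norm_out * wgt` inside `norm_backward` rather than to the raw input as here; the algebra is identical since `d/dx act(-x) = -act'(-x)`.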