diff --git a/README.md b/README.md index b5861dcc..04445056 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,10 @@ running quickly and with a small memory footprint of O(N^1.5) instead of O(N^2). Therefore, an axial mlp-mixer needs less memory and compute than a standard transformer while providing better performance at scale. +**UPDATE** +This is now merged with the dense convolution by running `cat([x.T, x]) @ W_big` rather than `x @ W` and in a separate +block `x.T @ W_2`. + #### Reversible Most commonly, large transformers use [activation checkpointing](https://arxiv.org/abs/1604.06174v2), which saves the @@ -90,7 +94,6 @@ speedup for both training and inference. ### Optimizer - #### Adaptive Gradient Clipping ## Getting Started diff --git a/config.yaml b/config.yaml index 81e49957..ada90f9e 100644 --- a/config.yaml +++ b/config.yaml @@ -18,8 +18,6 @@ dims: moe_intermediate: 4096 one: 1 outer_bottleneck_kernel: 25 - pointwise_features: 512 - pointwise_kernel: 5 sequence: 4096 vocab: 256 global_prefix: '' @@ -47,7 +45,6 @@ optimizer: momentum_beta: 0.1 norm_scale: 1 output_scale: 1 - pointwise_scale: 1 preconditioning_compute_steps: 128 qrnn_scale: 1 skip_preconditioning_dim_size_gt: 1024 diff --git a/src/backend.py b/src/backend.py index 8e8a361e..5c349e38 100644 --- a/src/backend.py +++ b/src/backend.py @@ -209,8 +209,9 @@ def loop(fn: Callable, fn_input: Any, steps: int, unroll: int = 1): typevar = TypeVar("typevar") +output = TypeVar("output") -def pattern_match(gen_fn: Callable[[int], Callable[[typevar], jax.Array]], cases: int, - predicate: jax.Array, base: typevar): +def pattern_match(gen_fn: Callable[[int], Callable[[typevar], output]], cases: int, + predicate: jax.Array, base: typevar) -> output: return lax.switch(predicate.astype(jnp.int32) % cases, [gen_fn(i) for i in range(cases)], base) diff --git a/src/context.py b/src/context.py index 71cc0665..3c81bd89 100644 --- a/src/context.py +++ b/src/context.py @@ -75,10 +75,7 @@ class Dims(DataClass): outer_bottleneck_kernel: int = 25 inner_bottleneck_kernel: int = 49 inner_bottleneck_features: int = 128 - pointwise_kernel: int = 5 features: int = 256 - spatial_mixing_kernel: int = 512 - pointwise_features: int = 512 sequence: int = 4096 depth: int = 8 vocab: int = 256 diff --git a/src/model/activate.py b/src/model/activate.py index 01d2f779..52b23e50 100644 --- a/src/model/activate.py +++ b/src/model/activate.py @@ -1,13 +1,21 @@ import jax -from jax import numpy as jnp +from jax import lax, numpy as jnp + + +# [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992) showed that +# "complicated" activation functions such as SiLU, GeLU and Mish help increase the rank of the embedding matrix. +# To ensure low-redundancy features, we use a function from that family. +def _tanh_sp(inp: jax.Array) -> jax.Array: # jax.nn.softplus without nan check + return lax.tanh(jnp.maximum(inp, 0.) 
+ lax.log1p(lax.exp(-lax.abs(inp)))) def activate_forward(inp: jax.Array) -> jax.Array: - return inp * activate_grad(inp) + return inp * _tanh_sp(inp) def activate_grad(inp: jax.Array) -> jax.Array: - return jnp.where(inp < 0, 0.01, 1) + tanh_sp = _tanh_sp(inp) + return tanh_sp + inp * jax.nn.sigmoid(inp) * (1 - lax.square(tanh_sp)) def activate(inp: jax.Array) -> jax.Array: diff --git a/src/model/conv.py b/src/model/conv.py index 1ffc4eca..983f74a8 100644 --- a/src/model/conv.py +++ b/src/model/conv.py @@ -1,19 +1,26 @@ +from typing import Tuple + import jax from jax import numpy as jnp -from src.backend import conv as lax_conv, get_param, square_grad, with_context +from src.backend import get_param, square_grad, with_context, conv as lax_conv from src.context import Context -from src.model.norm import prenorm, scale_norm_act @with_context() -def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False): +def conv_weight(ctx: Context, conv_kernel: int, in_features: int, out_features: int, tied: bool = False + ) -> Tuple[jax.Array, jax.Array]: fan_in = jnp.arange(conv_kernel, 0, -1, dtype=ctx.model.storage_dtype) fan_in = (1 - 1 / (conv_kernel * ctx.model.conv_scale + ctx.model.conv_shift)) ** fan_in fan_in = fan_in / fan_in.sum() fan_in = fan_in.reshape(1, 1, -1) - weight, weight_sq = get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2, - lr_scale=fan_in, tied=tied, return_sq=True) + return get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2, lr_scale=fan_in, + tied=tied, return_sq=True) + + +@with_context() +def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False): + weight, weight_sq = conv_weight(ctx, conv_kernel, in_features, out_features, tied) if ctx.is_initializing: return jnp.zeros(inp.shape[:-1] + (out_features,), dtype=inp.dtype) @@ -21,11 +28,3 @@ def _conv(x, y): return lax_conv(x, y, [(conv_kernel - 1, 0)], 1) return square_grad(_conv, inp, weight, weight_sq) - - -@prenorm -@with_context() -def dense_block(ctx: Context, inp: jax.Array) -> jax.Array: - inp = conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.features, ctx.dims.pointwise_features) - inp = scale_norm_act(ctx, inp, ctx.dims.pointwise_features) - return conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.pointwise_features, ctx.dims.features) diff --git a/src/model/dense.py b/src/model/dense.py new file mode 100644 index 00000000..0f462791 --- /dev/null +++ b/src/model/dense.py @@ -0,0 +1,45 @@ +import math + +import jax +from jax import numpy as jnp, lax + +from src.backend import with_context, pattern_match +from src.context import Context +from src.model.conv import conv +from src.model.norm import prenorm, scale_norm_act_conv + + +@prenorm +@with_context() +def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: + # Following [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992), we're + # not increasing the dimensionality in the middle, as the rank doesn't increase -> no useful features are added. 
+ + original_shape = inp.shape + original_batch, sequence, features = original_shape + max_dims = math.ceil(math.log(sequence, ctx.dims.features)) + + arange = jnp.arange(features) + mask = arange.reshape(1, -1, 1, 1) >= arange.reshape(1, 1, 1, -1) + + def _get_mix_fn(current_depth: int): + def _fn(x: jax.Array): + outer_seq = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1) + inner_seq = sequence // outer_seq # == dilation + inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq, features), (0, 1, 2, 3)) + inner = inner.reshape(1, -1, features)[:, :sequence] + + out = x.reshape(original_batch, outer_seq, features, inner_seq) + out = jnp.transpose(out, (0, 1, 3, 2)) + out = out.reshape(original_batch, sequence, features) + padded = lax.pad(out[:, :-inner_seq], jnp.zeros((), dtype=inp.dtype), + ((0, 0, 0), (inner_seq, 0, 0), (0, 0, 0))) + return out * inner, padded + + return _fn + + masked, padded = pattern_match(_get_mix_fn, max_dims, depth, inp) + inp_glu = inp * lax.pad(inp[:, :-1], jnp.ones((), dtype=inp.dtype), ((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + inp = sum(conv(ctx, i, 5, ctx.dims.features, 4 * ctx.dims.features) for i in (inp, inp_glu, masked, padded)) + return scale_norm_act_conv(ctx, inp, 5, 4 * ctx.dims.features, ctx.dims.features) diff --git a/src/model/main.py b/src/model/main.py index 409ea587..a2d51b47 100644 --- a/src/model/main.py +++ b/src/model/main.py @@ -5,9 +5,8 @@ from src.backend import get_param, is_model, is_stacked, square_grad, with_context from src.context import Context -from src.model.conv import dense_block +from src.model.dense import dense_block from src.model.loss import cross_entropy_loss -from src.model.mixer import mix from src.model.moe import dense_moe from src.model.norm import scale_norm_act from src.model.reversible import FourArrays, reversible, revnet_out @@ -37,10 +36,9 @@ def _fn(carry: FourArrays, inp: Tuple[Dict[str, jax.Array], jax.Array]): ctx.parameters.update(shared_params) depth = depth.reshape([]) src = [ctx.parameters] + list(carry) - src = reversible(ctx, dense_block, src) + src = reversible(ctx, dense_block, src, depth) src = reversible(ctx, dense_moe, src) - src = reversible(ctx, dense_block, src) - src = reversible(ctx, mix, src, depth) + src = reversible(ctx, dense_block, src, depth) name_cache.update(ctx.name_cache) if ctx.is_initializing: return src @@ -65,8 +63,7 @@ def stem(ctx: Context, src: FourArrays) -> FourArrays: return src -def body_ctx(ctx: Context, src: jax.Array) -> Union[ - Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: +def body_ctx(ctx: Context, src: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], jax.Array]: src = input_embed(ctx, src) zero = jnp.zeros_like(src) src = stem(ctx, (src, zero, src, zero)) diff --git a/src/model/mixer.py b/src/model/mixer.py deleted file mode 100644 index 7d2b42a1..00000000 --- a/src/model/mixer.py +++ /dev/null @@ -1,59 +0,0 @@ -import math -from typing import Sequence - -import jax -from jax import numpy as jnp - -from src.backend import dot, get_param, pattern_match, square_grad, with_context -from src.context import Context -from src.model.norm import prenorm, scale_norm_act - - -def dot_sq(src: jax.Array, weight: jax.Array, weight_sq: jax.Array, - left_contract_dims: Sequence[int], right_contract_dims: Sequence[int]): - def _dot(x, y): - return dot(x, y, left_contract_dims=left_contract_dims, right_contract_dims=right_contract_dims) - - return square_grad(_dot, src, weight, weight_sq) - - -@prenorm -@with_context() -def 
mix(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array: - weight_shape = [ctx.dims.spatial_mixing_kernel] * 2 - run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) - wgt0, wgt0_sq = get_param(ctx, "mix_0", weight_shape, return_sq=True) - wgt1, wgt1_sq = get_param(ctx, "mix_1", weight_shape, return_sq=True) - scale, scale_sq = get_param(ctx, "scale", [ctx.dims.features], std=0, mean=1, dtype=run_type, return_sq=True) - if ctx.is_initializing: - return inp - - original_shape = inp.shape - _batch, sequence, _features = original_shape - max_dims = math.ceil(math.log(sequence, ctx.dims.spatial_mixing_kernel)) - original_batch = inp.shape[0] - if ctx.model.autoregressive: - wgt0 = jnp.triu(wgt0) - wgt1 = jnp.triu(wgt1) - - def _get_mix_fn(current_depth: int): - def _fn(x: jax.Array): - batch = max(sequence // ctx.dims.spatial_mixing_kernel ** (current_depth % max_dims + 1), 1) - out = x.reshape(original_batch * batch, ctx.dims.spatial_mixing_kernel, -1) - inner_batch, inner_sequence, inner_features = out.shape - - # Shape[Batch, Sequence, Features] * Shape[Sequence, Sequence] -> Shape[Batch, Features, Sequence] - out = dot_sq(out, wgt0, wgt0_sq, left_contract_dims=(1,), right_contract_dims=(0,)) - - out = out.reshape(-1, ctx.dims.features, inner_sequence) - out = scale_norm_act(ctx, out, ctx.dims.features, weight=(scale, scale_sq), add_to_prefix=False, dim=1) - out = out.reshape(inner_batch, inner_features, inner_sequence) - - # Shape[Batch, Features, Sequence] * Shape[Sequence, Sequence] -> Shape[Batch, Features, Sequence] - out = dot_sq(out, wgt1, wgt1_sq, left_contract_dims=(2,), right_contract_dims=(0,)) - out = out.transpose(0, 2, 1) - return out.reshape(original_shape) - - return _fn - - return pattern_match(_get_mix_fn, max_dims, depth, inp) diff --git a/src/model/moe.py b/src/model/moe.py index d70ab46b..9073b90d 100644 --- a/src/model/moe.py +++ b/src/model/moe.py @@ -4,8 +4,7 @@ from src.backend import with_context from src.constants import ParallelAxes from src.context import Context -from src.model.conv import conv -from src.model.norm import prenorm, scale_norm_act +from src.model.norm import scale_norm_act, scale_norm_act_conv def all_to_all(ctx: Context, x: jax.Array, split_axis: int, concat_axis: int) -> jax.Array: @@ -22,7 +21,6 @@ def _grad(dy: jax.Array) -> jax.Array: return _fn(x) -@prenorm @with_context() def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: devices = jax.device_count() @@ -30,7 +28,7 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: batch, sequence, features = inp.shape sequence_slice = sequence // devices - inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features) + inp = scale_norm_act_conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features) # [Batch, Sequence, Features] -> [Batch, SequenceSlice, Features * Devices] # In essence, 1) Collect features from all devices + 2) Drop unused sequence elements @@ -40,9 +38,7 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: inp = inp.reshape(batch, sequence_slice, big_params) # Devices^2 more parameters than normal bottleneck block but only Devices-times more flops due to sparsity above - inp = scale_norm_act(ctx, inp, big_params) - inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True) - inp = scale_norm_act(ctx, inp, big_params) + inp = scale_norm_act_conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True) # [Batch, 
SequenceSlice, Features * Devices] -> [Batch, Sequence, Features] (PixelShuffle across devices) if not ctx.is_initializing: @@ -50,4 +46,5 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array: inp = all_to_all(ctx, inp, 3, 2) inp = inp.reshape(batch, sequence, ctx.dims.inner_bottleneck_features) - return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features) + out = scale_norm_act_conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features) + return scale_norm_act(ctx, out, features, act=False) diff --git a/src/model/norm.py b/src/model/norm.py index f748800f..f81a955e 100644 --- a/src/model/norm.py +++ b/src/model/norm.py @@ -1,12 +1,12 @@ -from typing import Tuple, Optional, Union, Callable +from typing import Tuple, Optional, Union, Callable, List import jax from jax import lax, numpy as jnp -from src.backend import get_param, promote_to, stable_rsqrt, with_context -from src.constants import ParallelAxes +from src.backend import get_param, promote_to, stable_rsqrt, with_context, conv as lax_conv from src.context import Context from src.model.activate import activate_forward, activate_grad +from src.model.conv import conv_weight def prenorm(fn: Callable[[Context, jax.Array], jax.Array]): @@ -19,38 +19,60 @@ def _fn(ctx: Context, inp: jax.Array, *args) -> jax.Array: return _fn -def all_gather(inp: jax.Array, dim: int) -> jax.Array: - @jax.custom_gradient - def _fn(x): - def _grad(dy): - return lax.psum_scatter(dy, axis_name=ParallelAxes.model, scatter_dimension=dim, tiled=True) - - return lax.all_gather(x, axis_name=ParallelAxes.model, axis=dim, tiled=True), _grad - - return _fn(inp) - - -def norm_forward(ctx: Context, src: jax.Array, wgt: Optional[jax.Array] = None, psum: bool = False, - act: bool = True, dim: int = 2): +def norm_forward(ctx: Context, src: jax.Array, wgt: Optional[jax.Array] = None, act: bool = True, dim: int = 2, + double: bool = False, std: Optional[jax.Array] = None): run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) src_fp64 = promote_to(src, run_type) - own_sum = lax.square(src_fp64).sum(dim, keepdims=True) - if psum: - own_sum = lax.psum(own_sum, ParallelAxes.model) - std = stable_rsqrt(own_sum, ctx.model.norm.eps) - out = src_fp64 * std * wgt - if act: + if std is None: + own_sum = lax.square(src_fp64).sum(dim, keepdims=True) + std = stable_rsqrt(own_sum, ctx.model.norm.eps) + norm_out = src_fp64 * std + out = multiplied = norm_out * wgt + if act and double: + out = jnp.concatenate([activate_forward(out), activate_forward(-out)], dim) + elif act: out = activate_forward(out) out = out.astype(src.dtype) - if psum: - out = all_gather(out, dim) - return out, std + return out, norm_out, multiplied, src_fp64, std + + +def norm_backward(ctx: Context, src: jax.Array, wgt: jax.Array, std: jax.Array, dy: jax.Array, act: bool, + dim: int, double: bool, weight_shape: List[int], run_type: jnp.dtype, + src_fp64: Optional[jax.Array] = None, norm_out: Optional[jax.Array] = None, + bw_out: Optional[jax.Array] = None): + src_fp64 = promote_to(src, run_type) if src_fp64 is None else src_fp64 + norm_out = (src_fp64 * std) if norm_out is None else norm_out + dy = promote_to(dy, run_type) + if act: + bw_out = (norm_out * wgt) if bw_out is None else bw_out + if double: + dy0, dy1 = jnp.split(dy, 2, dim) + dy = dy0 * activate_grad(bw_out) - dy1 * activate_grad(-bw_out) # skipcq: PYL-E1130 + else: + dy *= activate_grad(bw_out) + d_normed = dy * wgt + + d_std = (d_normed * 
src_fp64).sum(dim, keepdims=True) # broadcast forward -> sum backward + d_std *= std ** 3 # reciprocal + x^(1/pow) -> 1/std^2 * 1/std^(pow-1) * 1/pow + d_std *= src_fp64 # x^pow -> pow * x^(pow-1), multiply fused with above + dx = d_normed * std - d_std + dx = dx.astype(src.dtype) + + if not isinstance(wgt, jax.Array): + return dx, None, None + + summed = list(range(src.ndim)) + del summed[dim] + d_wgt = dy * norm_out + d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape(weight_shape).astype(run_type) + d_wgt = d_wgt.sum(summed).reshape(weight_shape).astype(run_type) + return dx, d_wgt, d_wgt_sq @with_context() def scale_norm_act(ctx: Context, inp: jax.Array, feature_dim: int, - weight: Union[bool, None, Tuple[jax.Array, jax.Array]] = None, - psum: bool = False, act: bool = True, dim: int = 2) -> jax.Array: + weight: Union[bool, None, Tuple[jax.Array, jax.Array]] = None, act: bool = True, dim: int = 2, + double: bool = False) -> jax.Array: run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) if weight is None: weight, weight_sq = get_param(ctx, "scale", [feature_dim], std=0, mean=1, dtype=run_type, return_sq=True) @@ -67,35 +89,45 @@ def _fn(src: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): if isinstance(wgt, jax.Array): wgt = wgt.reshape((1,) * dim + (-1,) + (1,) * (src.ndim - 1 - dim)) - out, std = norm_forward(ctx, src, wgt, psum, act, dim) + out, _, _, _, std = norm_forward(ctx, src, wgt, act, dim, double) def _grad(dy: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], Tuple[jax.Array, None, None]]: - inner_src = lax.all_gather(src, ParallelAxes.model, axis=dim) if psum else src - src_fp64 = promote_to(inner_src, run_type) - norm_out = src_fp64 * std - dy = promote_to(dy, run_type) - if act: - dy = dy * activate_grad(norm_out * wgt) - d_normed = dy * wgt - - d_std = (d_normed * src_fp64).sum(dim, keepdims=True) # broadcast forward -> sum backward - d_std *= std ** 3 # reciprocal + x^(1/pow) -> 1/std^2 * 1/std^(pow-1) * 1/pow - d_std *= src_fp64 # x^pow -> pow * x^(pow-1), multiply fused with above - dx = d_normed * std - d_std - if psum: - dx = lax.psum_scatter(dx, axis_name=ParallelAxes.model, scatter_dimension=dim, tiled=True) - dx = dx.astype(src.dtype) - - if not isinstance(wgt, jax.Array): - return dx, None, None - - summed = list(range(src.ndim)) - del summed[dim] - d_wgt = dy * norm_out - d_wgt_sq = (lax.square(d_wgt).sum(summed) * ctx.dims.batch).reshape((-1,)).astype(run_type) - d_wgt = d_wgt.sum(summed).reshape((-1,)).astype(run_type) - return dx, d_wgt, d_wgt_sq + shp = _wgt_dummy.shape if isinstance(weight, jax.Array) else () + return norm_backward(ctx, src, wgt, std, dy, act, dim, double, shp, run_type) return out, _grad return _fn(inp, weight, weight_sq) + + +@with_context() +def scale_norm_act_conv(ctx: Context, inp: jax.Array, kernel: int, in_features: int, out_features: int, + tied: bool = False) -> jax.Array: + run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32) + scale, scale_sq = get_param(ctx, "scale", [in_features], std=0, mean=1, dtype=run_type, return_sq=True) + weight, weight_sq = conv_weight(ctx, kernel, in_features, out_features, tied) + + if ctx.is_initializing: + return jnp.zeros(inp.shape[:-1] + (out_features,), dtype=inp.dtype) + + dim = inp.ndim - 1 + + def _conv(x, y): + return lax_conv(x, y, [(kernel - 1, 0)], 1) + + @jax.custom_gradient + def _fn(src: jax.Array, scl: jax.Array, _scl_dummy: jax.Array, wgt: jax.Array, _wgt_dummy: jax.Array): + scl = scl.reshape((1,) * dim + (-1,)) + 
out, _, _, _, std = norm_forward(ctx, src, scl, True, dim, False) + + def _grad(dy: jax.Array) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + out2, norm_out, bw_out, src_fp64, _ = norm_forward(ctx, src, scl, True, dim, False, std) + _, d_wgt_sq = jax.vjp(_conv, lax.square(out2), wgt)[1](lax.square(dy)) + dy, d_wgt = jax.vjp(_conv, out2, wgt)[1](dy) + dx, d_scl, d_scl_sq = norm_backward(ctx, src, scl, std, dy, True, dim, False, _scl_dummy.shape, run_type, + src_fp64, norm_out, bw_out) + return dx, d_scl, d_scl_sq, d_wgt, d_wgt_sq + + return lax_conv(out, wgt, [(kernel - 1, 0)], 1), _grad + + return _fn(inp, scale, scale_sq, weight, weight_sq) diff --git a/src/utils/checkpoint.py b/src/utils/checkpoint.py index 58e51033..36362b6d 100644 --- a/src/utils/checkpoint.py +++ b/src/utils/checkpoint.py @@ -24,7 +24,7 @@ UPLOAD_RETRIES = 8 WCTX_VALUES = ("scalars", "current_step") TMP_PATH_ADDON = "_____TEMPORARY" -GSUTIL_PATH = "/opt/google-cloud-sdk/bin/gsutil" +GSUTIL_PATH = "/snap/bin/gsutil" def log(arg: str, verbose: bool): diff --git a/test.sh b/test.sh new file mode 100644 index 00000000..2dd13b96 --- /dev/null +++ b/test.sh @@ -0,0 +1,9 @@ +export JAX_ENABLE_X64=1 # allow fp64 +export JAX_DEFAULT_DTYPE_BITS=64 # ..and enforce it + +# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla.proto +export XLA_FLAGS="--xla_force_host_platform_device_count=8" # We don't use TPU-CPU for ML +# export XLA_FLAGS="--xla_step_marker_location=1 $XLA_FLAGS" # 0 = entry; 1 = outer while +export PYTHONPATH=`pwd` + +/usr/bin/env python3 -m pytest "$@" # for example: `bash test.sh unittests/grad/norm.py` diff --git a/unittests/consistency/step.py b/unittests/consistency/step.py index d561c54d..6f3ff73e 100644 --- a/unittests/consistency/step.py +++ b/unittests/consistency/step.py @@ -15,10 +15,8 @@ def get_wctx(config: typing.Optional[typing.Dict[str, typing.Any]] = None): ctx = wctx.ctx ctx.dims.batch = 16 - ctx.dims.spatial_mixing_kernel = 8 ctx.dims.sequence = 128 ctx.dims.features = 16 - ctx.dims.pointwise_features = 32 ctx.dims.inner_bottleneck_features = 8 return wctx, ctx diff --git a/unittests/grad/backend.py b/unittests/grad/backend.py index ae1b2759..823d601e 100644 --- a/unittests/grad/backend.py +++ b/unittests/grad/backend.py @@ -14,18 +14,19 @@ def randn_fn(): def _fn(*shape: int): seed = rng.randint(0, 2 ** 30) + + def _gen(x): + return jax.random.normal(jax.random.PRNGKey(x + seed), shape, jnp.float32).astype(jnp.float64) / div + div = (shape[-1] * jax.device_count()) ** 0.25 - fn = jax.pmap( - lambda x: jax.random.normal(jax.random.PRNGKey(x + seed), shape, jnp.float32).astype(jnp.float_) / div) - local_devices = jax.local_device_count() - seeds = jnp.arange(local_devices * jax.process_index(), local_devices * (1 + jax.process_index())) - return fn(seeds) + devices = jax.local_device_count() + return jax.pmap(_gen)(jnp.arange(devices * jax.process_index(), devices * (1 + jax.process_index()))) return _fn def grad_fn(dy: jax.Array, *args): def _fn(fn): - return jax.pmap(jax.grad(lambda x: (fn(x) * dy).sum()), ParallelAxes.model)(args) + return jax.pmap(jax.grad(lambda x, y: (fn(x) * y).sum()), ParallelAxes.model)(args, dy) return _fn diff --git a/unittests/grad/leak.py b/unittests/grad/leak.py index 0f2568a3..23b970c8 100644 --- a/unittests/grad/leak.py +++ b/unittests/grad/leak.py @@ -7,6 +7,7 @@ from src.constants import ParallelAxes from src.context import Context +from src.main import add_zeros from src.model.main import stem from 
src.model.reversible import revnet_out from unittests.grad.backend import grad_fn, randn_fn, trials @@ -42,8 +43,6 @@ def test(samples: int, depth: int): ctx.dims.depth = depth ctx.dims.features = 8 ctx.dims.inner_bottleneck_features = 4 - ctx.dims.pointwise_features = 16 - ctx.dims.spatial_mixing_kernel = ctx.dims.sequence // 2 src = randn(ctx.dims.batch, ctx.dims.sequence, ctx.dims.features).astype(jnp.bfloat16) def _fn(x: jax.Array): @@ -53,6 +52,7 @@ def _fn(x: jax.Array): return params params = jax.pmap(_fn, ParallelAxes.model)(src) + add_zeros(params) ctx.is_initializing = False def _inner(inp: typing.Tuple[typing.Dict[str, jax.Array], jax.Array]): diff --git a/unittests/grad/norm.py b/unittests/grad/norm.py index 9a6c39a4..d9d56be4 100644 --- a/unittests/grad/norm.py +++ b/unittests/grad/norm.py @@ -1,4 +1,3 @@ -import jax import pytest from jax import numpy as jnp @@ -7,25 +6,25 @@ from unittests.grad.backend import grad_fn, randn_fn, sample_sizes, trials -def general_test(act: bool, psum: bool, samples: int, dim: int): # skipcq: PYL-W0640 +def general_test(act: bool, samples: int, dim: int, double: bool): # skipcq: PYL-W0640 ctx = Context() ctx.is_initializing = False randn = randn_fn() for trial in range(trials): src = randn(int(samples ** 0.5), int(samples ** 0.5), ctx.dims.features) - multiplier = jax.device_count() if psum else 1 out_shape = list(src.shape)[1:] - out_shape[dim] *= multiplier wgt = randn(out_shape[dim]) wgt_sq = randn(out_shape[dim]) + if double: + out_shape[dim] *= 2 dy = randn(*out_shape) print(dy.shape, src.shape, wgt.shape) grad = grad_fn(dy, src, wgt, wgt_sq) print(trial) shape = (1,) * dim + (-1,) + (1,) * (src.ndim - 2 - dim) - out0 = grad(lambda x: norm_forward(ctx, x[0], x[1].reshape(shape), bool(psum), act, dim)[0]) - out1 = grad(lambda x: scale_norm_act(ctx, x[0], ctx.dims.features, (x[1], x[2]), bool(psum), act, dim)) + out0 = grad(lambda x: norm_forward(ctx, x[0], x[1].reshape(shape), act, dim, double)[0]) + out1 = grad(lambda x: scale_norm_act(ctx, x[0], ctx.dims.features, (x[1], x[2]), act, dim, double)) assert jnp.allclose(out0[0], out1[0]) assert jnp.allclose(out0[1], out1[1]) @@ -34,16 +33,16 @@ def general_test(act: bool, psum: bool, samples: int, dim: int): # skipcq: PYL- @pytest.mark.parametrize("act", [True, False]) @pytest.mark.parametrize("samples", sample_sizes) def test_act(act: bool, samples: int): - general_test(act, False, samples, 2) + general_test(act, samples, 2, False) -@pytest.mark.parametrize("psum", [False, True]) +@pytest.mark.parametrize("dim", [0, 1, 2]) @pytest.mark.parametrize("samples", sample_sizes) -def test_psum(psum: bool, samples: int): - general_test(True, psum, samples, 2) +def test_dim(dim: int, samples: int): + general_test(True, samples, dim, False) -@pytest.mark.parametrize("dim", [0, 1, 2]) +@pytest.mark.parametrize("double", [False, True]) @pytest.mark.parametrize("samples", sample_sizes) -def test_dim(dim: int, samples: int): - general_test(True, False, samples, dim) +def test_double(double: bool, samples: int): + general_test(True, samples, 2, double)
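
---

Reviewer note, not part of the patch: the new activation in `src/model/activate.py` is `x * tanh(softplus(x))` (the Mish/TanhExp family referenced in the comment), and `activate_grad` is its closed-form derivative, `t(x) + x * sigmoid(x) * (1 - t(x)^2)` with `t(x) = tanh(softplus(x))`. The sketch below re-checks that analytic gradient against `jax.grad`; the function bodies are copied from the patch, while the test harness itself (input grid, tolerance) is my own assumption.

```python
import jax
from jax import lax, numpy as jnp


def _tanh_sp(inp: jax.Array) -> jax.Array:  # softplus without the nan check, as in the patch
    return lax.tanh(jnp.maximum(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp))))


def activate_forward(inp: jax.Array) -> jax.Array:
    return inp * _tanh_sp(inp)


def activate_grad(inp: jax.Array) -> jax.Array:
    tanh_sp = _tanh_sp(inp)
    return tanh_sp + inp * jax.nn.sigmoid(inp) * (1 - lax.square(tanh_sp))


# Compare the hand-written gradient with autodiff on a small grid (hypothetical check, not in the repo).
x = jnp.linspace(-4., 4., 257)
autodiff = jax.vmap(jax.grad(activate_forward))(x)
assert jnp.allclose(autodiff, activate_grad(x), atol=1e-5)
```

The closed-form gradient is what lets `norm_backward` and `scale_norm_act_conv` recompute the activation's derivative from saved tensors instead of storing autodiff residuals, which is the point of the fused norm-act-conv path added in `src/model/norm.py`.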