Arch2 #105 (Open)

24 commits
00c269c
test(optimizer): no truegrad
ClashLuke Dec 30, 2022
cdc8d0d
feat(model): use mish + merge mixer/conv
ClashLuke Jan 22, 2023
ae6b483
Revert "test(optimizer): no truegrad"
ClashLuke Jan 22, 2023
bf68bf6
fix(conv): remove leak
ClashLuke Jan 22, 2023
99f2d25
fix: lots of minor issue
ClashLuke Jan 22, 2023
e28bd98
feat(tests): add test script
ClashLuke Jan 22, 2023
f5dbac3
style(main): fix type hint indentation
ClashLuke Jan 22, 2023
6e7130e
perf(dense): use X+GLU+Mixer+Shift as input to dense; add more effici…
ClashLuke Jan 24, 2023
0a1ea7c
style: fix deepsource issues
ClashLuke Jan 25, 2023
504153f
fix(dense): use correct lax.pad dtype
ClashLuke Feb 8, 2023
a12d843
fix(dense): max path doesn't work with pad
ClashLuke Feb 8, 2023
81b153a
fix(dense): *features was for numel, not idx in dim
ClashLuke Feb 8, 2023
7907a3d
fix: add missing () to with_context
ClashLuke Feb 8, 2023
9292d83
fix(dense): handle unreshapable mask case
ClashLuke Feb 8, 2023
4b8eb08
fix(activate): casting issue
ClashLuke Feb 8, 2023
99ddce3
fix(norm): normalize out_features
ClashLuke Feb 8, 2023
b6da249
fix(norm): revert
ClashLuke Feb 8, 2023
ee000be
test(norm): show shapes
ClashLuke Feb 8, 2023
553fd42
test(norm): show name cache and prefix
ClashLuke Feb 8, 2023
289a9f1
fix(norm): change shape in norm conv
ClashLuke Feb 8, 2023
2b728d2
fix(norm): use correct shape
ClashLuke Feb 8, 2023
ea36c42
fix(norm): only get shape if array
ClashLuke Feb 8, 2023
d57a657
v4
ClashLuke Feb 8, 2023
ba2e311
perf(dense): matmul during memory op
ClashLuke Jan 20, 2024
5 changes: 4 additions & 1 deletion README.md
@@ -48,6 +48,10 @@ running quickly and with a small memory footprint of O(N^1.5) instead of O(N^2).
Therefore, an axial mlp-mixer needs less memory and compute than a standard transformer while providing better
performance at scale.

**UPDATE**
The axial mixer is now merged with the dense convolution: a single block computes `cat([x.T, x]) @ W_big` instead of
running `x @ W` and, in a separate block, `x.T @ W_2`.
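
For intuition, a minimal JAX sketch (with made-up square shapes and weight names `w_x`, `w_t`) of the concatenation trick: for a single linear layer, `cat([x.T, x]) @ W_big` with a stacked weight equals the sum of the two separate matmuls, so both streams can share one larger matrix multiply.

```python
import jax
from jax import numpy as jnp

# Hypothetical square block so the transposed (mixer) stream and the
# untransposed (dense) stream can share one matmul.
s = f = 8
x = jax.random.normal(jax.random.PRNGKey(0), (s, f))
w_x = jax.random.normal(jax.random.PRNGKey(1), (f, f))  # weight for x
w_t = jax.random.normal(jax.random.PRNGKey(2), (f, f))  # weight for x.T

separate = x @ w_x + x.T @ w_t                 # two separate blocks: x @ W and x.T @ W_2
w_big = jnp.concatenate([w_t, w_x], axis=0)    # stacked weight, shape (2f, f)
merged = jnp.concatenate([x.T, x], axis=-1) @ w_big

assert jnp.allclose(separate, merged, atol=1e-4)
```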

#### Reversible

Most commonly, large transformers use [activation checkpointing](https://arxiv.org/abs/1604.06174v2), which saves the
@@ -90,7 +94,6 @@ speedup for both training and inference.

### Optimizer


#### Adaptive Gradient Clipping

## Getting Started
3 changes: 0 additions & 3 deletions config.yaml
@@ -18,8 +18,6 @@ dims:
moe_intermediate: 4096
one: 1
outer_bottleneck_kernel: 25
pointwise_features: 512
pointwise_kernel: 5
sequence: 4096
vocab: 256
global_prefix: ''
@@ -47,7 +45,6 @@ optimizer:
momentum_beta: 0.1
norm_scale: 1
output_scale: 1
pointwise_scale: 1
preconditioning_compute_steps: 128
qrnn_scale: 1
skip_preconditioning_dim_size_gt: 1024
5 changes: 3 additions & 2 deletions src/backend.py
@@ -209,8 +209,9 @@ def loop(fn: Callable, fn_input: Any, steps: int, unroll: int = 1):


typevar = TypeVar("typevar")
output = TypeVar("output")


def pattern_match(gen_fn: Callable[[int], Callable[[typevar], jax.Array]], cases: int,
predicate: jax.Array, base: typevar):
def pattern_match(gen_fn: Callable[[int], Callable[[typevar], output]], cases: int,
predicate: jax.Array, base: typevar) -> output:
return lax.switch(predicate.astype(jnp.int32) % cases, [gen_fn(i) for i in range(cases)], base)
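
For context, a quick standalone usage sketch of `pattern_match` (the branch functions and scale factors below are made up): `gen_fn` builds one branch per case and `lax.switch` dispatches on `predicate % cases`.

```python
from jax import lax, numpy as jnp


def pattern_match(gen_fn, cases, predicate, base):
    return lax.switch(predicate.astype(jnp.int32) % cases, [gen_fn(i) for i in range(cases)], base)


def _get_scale_fn(i: int):
    # Each case bakes a different static scale factor into its branch.
    return lambda x: x * (i + 1)


x = jnp.ones((4,))
depth = jnp.asarray(2)
print(pattern_match(_get_scale_fn, 4, depth, x))  # [3. 3. 3. 3.] -- branch 2 scales by 3
```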
3 changes: 0 additions & 3 deletions src/context.py
@@ -75,10 +75,7 @@ class Dims(DataClass):
outer_bottleneck_kernel: int = 25
inner_bottleneck_kernel: int = 49
inner_bottleneck_features: int = 128
pointwise_kernel: int = 5
features: int = 256
spatial_mixing_kernel: int = 512
pointwise_features: int = 512
sequence: int = 4096
depth: int = 8
vocab: int = 256
14 changes: 11 additions & 3 deletions src/model/activate.py
@@ -1,13 +1,21 @@
import jax
from jax import numpy as jnp
from jax import lax, numpy as jnp


# [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992) showed that
# "complicated" activation functions such as SiLU, GeLU and Mish help increase the rank of the embedding matrix.
# To ensure low-redundancy features, we use a function from that family.
def _tanh_sp(inp: jax.Array) -> jax.Array: # jax.nn.softplus without nan check
return lax.tanh(jnp.maximum(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp))))


def activate_forward(inp: jax.Array) -> jax.Array:
return inp * activate_grad(inp)
return inp * _tanh_sp(inp)


def activate_grad(inp: jax.Array) -> jax.Array:
return jnp.where(inp < 0, 0.01, 1)
tanh_sp = _tanh_sp(inp)
return tanh_sp + inp * jax.nn.sigmoid(inp) * (1 - lax.square(tanh_sp))


def activate(inp: jax.Array) -> jax.Array:
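
For reference, `activate_forward` above is the Mish-style `x * tanh(softplus(x))`, and `activate_grad` is its closed-form derivative from the chain rule, `tanh(sp(x)) + x * sigmoid(x) * (1 - tanh(sp(x))^2)` with `sp'(x) = sigmoid(x)`. A standalone sanity check against autodiff (functions copied from the diff above):

```python
import jax
from jax import lax, numpy as jnp


def _tanh_sp(inp):  # tanh(softplus(x)) without the nan check
    return lax.tanh(jnp.maximum(inp, 0.) + lax.log1p(lax.exp(-lax.abs(inp))))


def activate_forward(inp):
    return inp * _tanh_sp(inp)


def activate_grad(inp):
    tanh_sp = _tanh_sp(inp)
    return tanh_sp + inp * jax.nn.sigmoid(inp) * (1 - lax.square(tanh_sp))


x = jnp.linspace(-4., 4., 9)
assert jnp.allclose(jax.vmap(jax.grad(activate_forward))(x), activate_grad(x), atol=1e-4)
```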
25 changes: 12 additions & 13 deletions src/model/conv.py
@@ -1,31 +1,30 @@
from typing import Tuple

import jax
from jax import numpy as jnp

from src.backend import conv as lax_conv, get_param, square_grad, with_context
from src.backend import get_param, square_grad, with_context, conv as lax_conv
from src.context import Context
from src.model.norm import prenorm, scale_norm_act


@with_context()
def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False):
def conv_weight(ctx: Context, conv_kernel: int, in_features: int, out_features: int, tied: bool = False
) -> Tuple[jax.Array, jax.Array]:
fan_in = jnp.arange(conv_kernel, 0, -1, dtype=ctx.model.storage_dtype)
fan_in = (1 - 1 / (conv_kernel * ctx.model.conv_scale + ctx.model.conv_shift)) ** fan_in
fan_in = fan_in / fan_in.sum()
fan_in = fan_in.reshape(1, 1, -1)
weight, weight_sq = get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2,
lr_scale=fan_in, tied=tied, return_sq=True)
return get_param(ctx, "conv_weight", [out_features, conv_kernel, in_features], column_axes=2, lr_scale=fan_in,
tied=tied, return_sq=True)


@with_context()
def conv(ctx: Context, inp: jax.Array, conv_kernel: int, in_features: int, out_features: int, tied: bool = False):
weight, weight_sq = conv_weight(ctx, conv_kernel, in_features, out_features, tied)
if ctx.is_initializing:
return jnp.zeros(inp.shape[:-1] + (out_features,), dtype=inp.dtype)

def _conv(x, y):
return lax_conv(x, y, [(conv_kernel - 1, 0)], 1)

return square_grad(_conv, inp, weight, weight_sq)


@prenorm
@with_context()
def dense_block(ctx: Context, inp: jax.Array) -> jax.Array:
inp = conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.features, ctx.dims.pointwise_features)
inp = scale_norm_act(ctx, inp, ctx.dims.pointwise_features)
return conv(ctx, inp, ctx.dims.pointwise_kernel, ctx.dims.pointwise_features, ctx.dims.features)
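
As a side note, the `lr_scale=fan_in` passed to `get_param` gives each kernel tap a geometrically decaying, normalized learning-rate scale; a quick numeric illustration with made-up config values (`conv_scale=1.0`, `conv_shift=0.0`, kernel size 5):

```python
from jax import numpy as jnp

conv_kernel, conv_scale, conv_shift = 5, 1.0, 0.0                       # hypothetical values
fan_in = jnp.arange(conv_kernel, 0, -1, dtype=jnp.float32)              # [5. 4. 3. 2. 1.]
fan_in = (1 - 1 / (conv_kernel * conv_scale + conv_shift)) ** fan_in    # 0.8 ** [5..1]
fan_in = fan_in / fan_in.sum()
print(fan_in)  # ~[0.12 0.15 0.19 0.24 0.30]: taps later along the kernel axis get larger scales
```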
45 changes: 45 additions & 0 deletions src/model/dense.py
@@ -0,0 +1,45 @@
import math

import jax
from jax import numpy as jnp, lax

from src.backend import with_context, pattern_match
from src.context import Context
from src.model.conv import conv
from src.model.norm import prenorm, scale_norm_act_conv


@prenorm
@with_context()
def dense_block(ctx: Context, inp: jax.Array, depth: jax.Array) -> jax.Array:
# Following [Rethinking Channel Dimensions for Efficient Model Design](https://arxiv.org/abs/2007.00992), we're
# not increasing the dimensionality in the middle, as the rank doesn't increase -> no useful features are added.

original_shape = inp.shape
original_batch, sequence, features = original_shape
max_dims = math.ceil(math.log(sequence, ctx.dims.features))

arange = jnp.arange(features)
mask = arange.reshape(1, -1, 1, 1) >= arange.reshape(1, 1, 1, -1)

def _get_mix_fn(current_depth: int):
def _fn(x: jax.Array):
outer_seq = max(sequence // ctx.dims.features ** (current_depth % max_dims + 1), 1)
inner_seq = sequence // outer_seq # == dilation
inner = lax.broadcast_in_dim(mask, (outer_seq, features, inner_seq, features), (0, 1, 2, 3))
inner = inner.reshape(1, -1, features)[:, :sequence]

out = x.reshape(original_batch, outer_seq, features, inner_seq)
out = jnp.transpose(out, (0, 1, 3, 2))
out = out.reshape(original_batch, sequence, features)
padded = lax.pad(out[:, :-inner_seq], jnp.zeros((), dtype=inp.dtype),
((0, 0, 0), (inner_seq, 0, 0), (0, 0, 0)))
return out * inner, padded

return _fn

masked, padded = pattern_match(_get_mix_fn, max_dims, depth, inp)
inp_glu = inp * lax.pad(inp[:, :-1], jnp.ones((), dtype=inp.dtype), ((0, 0, 0), (1, 0, 0), (0, 0, 0)))

inp = sum(conv(ctx, i, 5, ctx.dims.features, 4 * ctx.dims.features) for i in (inp, inp_glu, masked, padded))
return scale_norm_act_conv(ctx, inp, 5, 4 * ctx.dims.features, ctx.dims.features)
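
For intuition about the reshapes inside `_get_mix_fn`, a standalone toy sketch of the axial-mixing idea (not the repository code): fold the sequence into an outer/inner grid and swap the axes, so the following causal convolution mixes tokens that were `inner_seq` positions apart.

```python
import jax
from jax import numpy as jnp

batch, sequence, features = 2, 16, 4       # toy sizes
outer_seq = sequence // features           # 4
inner_seq = sequence // outer_seq          # 4, acts like a dilation factor

x = jax.random.normal(jax.random.PRNGKey(0), (batch, sequence, features))
grid = x.reshape(batch, outer_seq, inner_seq, features)
mixed = jnp.transpose(grid, (0, 2, 1, 3)).reshape(batch, sequence, features)
print(mixed.shape)  # (2, 16, 4): same shape, but neighbouring tokens were originally inner_seq apart
```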
11 changes: 4 additions & 7 deletions src/model/main.py
@@ -5,9 +5,8 @@

from src.backend import get_param, is_model, is_stacked, square_grad, with_context
from src.context import Context
from src.model.conv import dense_block
from src.model.dense import dense_block
from src.model.loss import cross_entropy_loss
from src.model.mixer import mix
from src.model.moe import dense_moe
from src.model.norm import scale_norm_act
from src.model.reversible import FourArrays, reversible, revnet_out
@@ -37,10 +36,9 @@ def _fn(carry: FourArrays, inp: Tuple[Dict[str, jax.Array], jax.Array]):
ctx.parameters.update(shared_params)
depth = depth.reshape([])
src = [ctx.parameters] + list(carry)
src = reversible(ctx, dense_block, src)
src = reversible(ctx, dense_block, src, depth)
src = reversible(ctx, dense_moe, src)
src = reversible(ctx, dense_block, src)
src = reversible(ctx, mix, src, depth)
src = reversible(ctx, dense_block, src, depth)
name_cache.update(ctx.name_cache)
if ctx.is_initializing:
return src
@@ -65,8 +63,7 @@ def stem(ctx: Context, src: FourArrays) -> FourArrays:
return src


def body_ctx(ctx: Context, src: jax.Array) -> Union[
Tuple[jax.Array, jax.Array, jax.Array], jax.Array]:
def body_ctx(ctx: Context, src: jax.Array) -> Union[Tuple[jax.Array, jax.Array, jax.Array], jax.Array]:
src = input_embed(ctx, src)
zero = jnp.zeros_like(src)
src = stem(ctx, (src, zero, src, zero))
59 changes: 0 additions & 59 deletions src/model/mixer.py

This file was deleted.

13 changes: 5 additions & 8 deletions src/model/moe.py
@@ -4,8 +4,7 @@
from src.backend import with_context
from src.constants import ParallelAxes
from src.context import Context
from src.model.conv import conv
from src.model.norm import prenorm, scale_norm_act
from src.model.norm import scale_norm_act, scale_norm_act_conv


def all_to_all(ctx: Context, x: jax.Array, split_axis: int, concat_axis: int) -> jax.Array:
@@ -22,15 +21,14 @@ def _grad(dy: jax.Array) -> jax.Array:
return _fn(x)


@prenorm
@with_context()
def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array:
devices = jax.device_count()
big_params = devices * ctx.dims.inner_bottleneck_features
batch, sequence, features = inp.shape
sequence_slice = sequence // devices

inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features)
inp = scale_norm_act_conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, features, ctx.dims.inner_bottleneck_features)

# [Batch, Sequence, Features] -> [Batch, SequenceSlice, Features * Devices]
# In essence, 1) Collect features from all devices + 2) Drop unused sequence elements
@@ -40,14 +38,13 @@ def dense_moe(ctx: Context, inp: jax.Array) -> jax.Array:
inp = inp.reshape(batch, sequence_slice, big_params)

# Devices^2 more parameters than normal bottleneck block but only Devices-times more flops due to sparsity above
inp = scale_norm_act(ctx, inp, big_params)
inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True)
inp = scale_norm_act(ctx, inp, big_params)
inp = scale_norm_act_conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, big_params, big_params, tied=True)

# [Batch, SequenceSlice, Features * Devices] -> [Batch, Sequence, Features] (PixelShuffle across devices)
if not ctx.is_initializing:
inp = inp.reshape(batch, sequence_slice, 1, big_params)
inp = all_to_all(ctx, inp, 3, 2)
inp = inp.reshape(batch, sequence, ctx.dims.inner_bottleneck_features)

return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features)
out = scale_norm_act_conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, ctx.dims.inner_bottleneck_features, features)
return scale_norm_act(ctx, out, features, act=False)
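
To illustrate the shape bookkeeping of the pixel shuffle across devices, a single-host sketch with toy sizes (the real code performs this with a device all-to-all collective inside the parallel map, not plain reshapes):

```python
import jax
from jax import numpy as jnp

batch, sequence, features, devices = 2, 8, 4, 4   # toy sizes
sequence_slice = sequence // devices

x = jax.random.normal(jax.random.PRNGKey(0), (batch, sequence, features))

# [Batch, Sequence, Features] -> [Batch, SequenceSlice, Features * Devices]
gathered = x.reshape(batch, sequence_slice, devices, features)
gathered = gathered.reshape(batch, sequence_slice, devices * features)
print(gathered.shape)  # (2, 2, 16)

# Inverse shuffle: back to [Batch, Sequence, Features]
restored = gathered.reshape(batch, sequence_slice, devices, features).reshape(batch, sequence, features)
assert jnp.allclose(restored, x)
```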