diff --git a/README.md b/README.md
index 24494c4..d25b255 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,7 @@ from lambda_networks import λLayer
 Shinel94 has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under `./lambda_networks/tfkeras.py` or make sure to install `tensorflow` and `keras` before running the following.
 
 ```python
+import tensorflow as tf
 from lambda_networks.tfkeras import LambdaLayer
 
 layer = LambdaLayer(
@@ -70,8 +71,11 @@ layer = LambdaLayer(
     r = 23,
     dim_k = 16,
     heads = 4,
-    dim_u = 4
+    dim_u = 1
 )
+
+x = tf.random.normal((1, 64, 64, 16))
+layer(x)
 ```
 
 ## Citations
diff --git a/lambda_networks/tfkeras.py b/lambda_networks/tfkeras.py
index f51951c..e2c582b 100644
--- a/lambda_networks/tfkeras.py
+++ b/lambda_networks/tfkeras.py
@@ -1,6 +1,6 @@
-from einops.layers.keras import Rearrange
-from keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
-from keras import initializers
+from einops.layers.tensorflow import Rearrange
+from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
+from tensorflow.keras import initializers
 from tensorflow import einsum
 
 # helpers functions
@@ -12,7 +12,6 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-
 # lambda layer
 
 class LambdaLayer(Layer):
@@ -46,8 +45,7 @@ def __init__(
         self.local_contexts = exists(r)
         if exists(r):
             assert (r % 2) == 1, 'Receptive kernel size should be odd'
-            self.pos_padding = ZeroPadding3D(padding=(0, r//2, r//2))
-            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='valid')
+            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
         else:
             assert exists(n), 'You must specify the total sequence length (h x w)'
             self.pos_emb = self.add_weight(name='pos_emb',
@@ -56,7 +54,7 @@ def __init__(
                                            trainable=True)
 
     def call(self, inputs, **kwargs):
-        b, c, hh, ww = inputs.get_shape().as_list()
+        b, hh, ww, c = inputs.get_shape().as_list()
         u, h = self.u, self.heads
         x = inputs
 
@@ -67,9 +65,9 @@ def call(self, inputs, **kwargs):
         q = self.norm_q(q)
         v = self.norm_v(v)
 
-        q = Rearrange('b (h k) hh ww -> b h k (hh ww)', h=h)(q)
-        k = Rearrange('b (u k) hh ww -> b u k (hh ww)', u=u)(k)
-        v = Rearrange('b (u v) hh ww -> b u v (hh ww)', u=u)(v)
+        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
+        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
+        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)
 
         k = Softmax()(k)
 
@@ -77,23 +75,22 @@ def call(self, inputs, **kwargs):
         Yc = Lambda(lambda x: einsum('b h k n, b k v -> b n h v', x[0], x[1]))([q, Lc])
 
         if self.local_contexts:
-            v = Rearrange('b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)(v)
-            Lp = self.pos_padding(v)
-            Lp = self.pos_conv(Lp)
-            Lp = Rearrange('b c k h w -> b c k (h w)')(Lp)
-            Yp = Lambda(lambda x: einsum('b h k n, b k v n -> b n h v', x[0], x[1]))([q, Lp])
+            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
+            Lp = self.pos_conv(v)
+            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
+            Yp = Lambda(lambda x: einsum('b h k n, b v k n -> b n h v', x[0], x[1]))([q, Lp])
         else:
             Lp = Lambda(lambda x: einsum('n m k u, b u v m -> b n k v', x[0], x[1]))([self.pos_emb, v])
             Yp = Lambda(lambda x: einsum('b h k n, b n k v -> b n h v', x[0], x[1]))([q, Lp])
 
         Y = Add()([Yc, Yp])
-        out = Rearrange('b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)(Y)
+        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
         return out
 
     def compute_output_shape(self, input_shape):
-        return (input_shape[0], self.out_dim, input_shape[2], input_shape[3])
+        return (input_shape[0], input_shape[1], input_shape[2], self.out_dim)
 
     def get_config(self):
-        config = {'output_dim': (self.input_shape[0], self.out_dim, self.input_shape[2], self.input_shape[3])}
+        config = {'output_dim': (self.input_shape[0], self.input_shape[1], self.input_shape[2], self.out_dim)}
         base_config = super(LambdaLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
diff --git a/setup.py b/setup.py
index 200f885..5a2caaf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'lambda-networks',
   packages = find_packages(),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'Lambda Networks - Pytorch',
   author = 'Phil Wang',
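As a quick sanity check of the patch (not part of the diff itself), here is a minimal sketch that exercises both positional-lambda paths of the updated channels-last layer: the local-context path via an odd receptive field `r`, and the global path via `n`, the total sequence length (h x w). It assumes the patched 0.3.1 `lambda_networks/tfkeras.py` is installed; the `dim_out` keyword is taken from the repository README rather than from the hunks above.

```python
import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

# Channels-last input, matching the patched call(): (b, hh, ww, c)
x = tf.random.normal((1, 8, 8, 16))

# Local-context path: receptive field r must be odd; padding='same'
# keeps the spatial size, so no ZeroPadding3D is needed anymore.
local = LambdaLayer(dim_out = 32, r = 7, dim_k = 16, heads = 4, dim_u = 1)
print(local(x).shape)  # expected: (1, 8, 8, 32)

# Global path: r omitted, so n = hh * ww is required to build pos_emb.
glob = LambdaLayer(dim_out = 32, n = 8 * 8, dim_k = 16, heads = 4, dim_u = 1)
print(glob(x).shape)   # expected: (1, 8, 8, 32)
```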