Commit 3d581cf

fix a bunch of bugs with tf/keras version
lucidrains committed Oct 20, 2020
1 parent d5abbf3 commit 3d581cf
Showing 3 changed files with 21 additions and 20 deletions.
6 changes: 5 additions & 1 deletion README.md

````diff
@@ -63,15 +63,19 @@ from lambda_networks import λLayer
 <a href="https://github.com/shinel94">Shinel94</a> has added a Keras implementation! It won't be officially supported in this repository, so either copy / paste the code under `./lambda_networks/tfkeras.py` or make sure to install `tensorflow` and `keras` before running the following.
 
 ```python
+import tensorflow as tf
 from lambda_networks.tfkeras import LambdaLayer
 
 layer = LambdaLayer(
     dim_out = 32,
     r = 23,
     dim_k = 16,
     heads = 4,
-    dim_u = 4
+    dim_u = 1
 )
+
+x = tf.random.normal((1, 64, 64, 16))
+layer(x)
 ```
 
 ## Citations
````
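With the channels-last fixes below in `lambda_networks/tfkeras.py`, this example runs as written: for the `(1, 64, 64, 16)` input and `dim_out = 32`, `layer(x)` should come back as a tensor of shape `(1, 64, 64, 32)`.
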
33 changes: 15 additions & 18 deletions lambda_networks/tfkeras.py

```diff
@@ -1,6 +1,6 @@
-from einops.layers.keras import Rearrange
-from keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
-from keras import initializers
+from einops.layers.tensorflow import Rearrange
+from tensorflow.keras.layers import Conv2D, BatchNormalization, Conv3D, ZeroPadding3D, Softmax, Lambda, Add, Layer
+from tensorflow.keras import initializers
 from tensorflow import einsum
 
 # helpers functions
```

```diff
@@ -12,7 +12,6 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
-
 # lambda layer
 
 class LambdaLayer(Layer):
```

```diff
@@ -46,8 +45,7 @@ def __init__(
         self.local_contexts = exists(r)
         if exists(r):
             assert (r % 2) == 1, 'Receptive kernel size should be odd'
-            self.pos_padding = ZeroPadding3D(padding=(0, r//2, r//2))
-            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='valid')
+            self.pos_conv = Conv3D(dim_k, (1, r, r), padding='same')
         else:
             assert exists(n), 'You must specify the total sequence length (h x w)'
             self.pos_emb = self.add_weight(name='pos_emb',
```
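
For an odd kernel size `r`, `padding='same'` pads `r // 2` zeros on each spatial side, which is exactly what the removed `ZeroPadding3D` + `padding='valid'` pair did, so the change preserves output shape while dropping a layer. A minimal sketch of that shape equivalence (the tensor sizes here are illustrative, not from the repository):

```python
import tensorflow as tf
from tensorflow.keras.layers import Conv3D, ZeroPadding3D

r = 5                                  # any odd receptive kernel size
x = tf.random.normal((1, 4, 8, 8, 3))  # (batch, u, height, width, channels)

# new style: let the convolution pad for itself
same_out = Conv3D(16, (1, r, r), padding='same')(x)

# old style: explicit zero padding, then a 'valid' convolution
padded = ZeroPadding3D(padding=(0, r // 2, r // 2))(x)
valid_out = Conv3D(16, (1, r, r), padding='valid')(padded)

print(same_out.shape, valid_out.shape)  # both (1, 4, 8, 8, 16)
```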

```diff
@@ -56,7 +54,7 @@ def __init__(
                                            trainable=True)
 
     def call(self, inputs, **kwargs):
-        b, c, hh, ww = inputs.get_shape().as_list()
+        b, hh, ww, c = inputs.get_shape().as_list()
         u, h = self.u, self.heads
         x = inputs
 
```
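
Keras convolutions default to `data_format='channels_last'`, so a 4-D activation is laid out as `(batch, height, width, channels)`; the old unpacking `b, c, hh, ww` assumed PyTorch's channels-first layout and read the spatial and channel dimensions out of the wrong slots. The `Rearrange` patterns in the next hunk are updated to match.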

```diff
@@ -67,33 +65,32 @@ def call(self, inputs, **kwargs):
         q = self.norm_q(q)
         v = self.norm_v(v)
 
-        q = Rearrange('b (h k) hh ww -> b h k (hh ww)', h=h)(q)
-        k = Rearrange('b (u k) hh ww -> b u k (hh ww)', u=u)(k)
-        v = Rearrange('b (u v) hh ww -> b u v (hh ww)', u=u)(v)
+        q = Rearrange('b hh ww (h k) -> b h k (hh ww)', h=h)(q)
+        k = Rearrange('b hh ww (u k) -> b u k (hh ww)', u=u)(k)
+        v = Rearrange('b hh ww (u v) -> b u v (hh ww)', u=u)(v)
 
         k = Softmax()(k)
 
         Lc = Lambda(lambda x: einsum('b u k m, b u v m -> b k v', x[0], x[1]))([k, v])
         Yc = Lambda(lambda x: einsum('b h k n, b k v -> b n h v', x[0], x[1]))([q, Lc])
 
         if self.local_contexts:
-            v = Rearrange('b u v (hh ww) -> b u v hh ww', hh=hh, ww=ww)(v)
-            Lp = self.pos_padding(v)
-            Lp = self.pos_conv(Lp)
-            Lp = Rearrange('b c k h w -> b c k (h w)')(Lp)
-            Yp = Lambda(lambda x: einsum('b h k n, b k v n -> b n h v', x[0], x[1]))([q, Lp])
+            v = Rearrange('b u v (hh ww) -> b v hh ww u', hh=hh, ww=ww)(v)
+            Lp = self.pos_conv(v)
+            Lp = Rearrange('b v h w k -> b v k (h w)')(Lp)
+            Yp = Lambda(lambda x: einsum('b h k n, b v k n -> b n h v', x[0], x[1]))([q, Lp])
         else:
             Lp = Lambda(lambda x: einsum('n m k u, b u v m -> b n k v', x[0], x[1]))([self.pos_emb, v])
             Yp = Lambda(lambda x: einsum('b h k n, b n k v -> b n h v', x[0], x[1]))([q, Lp])
 
         Y = Add()([Yc, Yp])
-        out = Rearrange('b (hh ww) h v -> b (h v) hh ww', hh = hh, ww = ww)(Y)
+        out = Rearrange('b (hh ww) h v -> b hh ww (h v)', hh = hh, ww = ww)(Y)
         return out
 
     def compute_output_shape(self, input_shape):
-        return (input_shape[0], self.out_dim, input_shape[2], input_shape[3])
+        return (input_shape[0], input_shape[1], input_shape[2], self.out_dim)
 
     def get_config(self):
-        config = {'output_dim': (self.input_shape[0], self.out_dim, self.input_shape[2], self.input_shape[3])}
+        config = {'output_dim': (self.input_shape[0], self.input_shape[1], self.input_shape[2], self.out_dim)}
         base_config = super(LambdaLayer, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
```
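
Taken together, the `call` changes move the layer from PyTorch-style NCHW to Keras-native NHWC end to end: the `Rearrange` patterns split heads and intra-depth out of the last axis, the local-context branch feeds `Conv3D` a channels-last tensor with `u` as the channel axis, and the output is reassembled as `(batch, height, width, dim_out)`. A quick smoke test of the fixed layer (a sketch, assuming TensorFlow 2.x and this version of the package are installed):

```python
import tensorflow as tf
from lambda_networks.tfkeras import LambdaLayer

layer = LambdaLayer(
    dim_out = 32,  # channels of the output feature map
    r = 23,        # local receptive field; must be odd
    dim_k = 16,    # key / query dimension
    heads = 4,     # dim_out should be divisible by heads
    dim_u = 1      # intra-depth of keys / values
)

# channels-last input, as Keras expects
x = tf.random.normal((1, 64, 64, 16))
y = layer(x)
print(y.shape)  # (1, 64, 64, 32): spatial dims preserved, channels -> dim_out
```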
2 changes: 1 addition & 1 deletion setup.py

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'lambda-networks',
   packages = find_packages(),
-  version = '0.3.0',
+  version = '0.3.1',
   license='MIT',
   description = 'Lambda Networks - Pytorch',
   author = 'Phil Wang',
```
