# attention.py
import tensorflow as tf


class SelfAttnModel(tf.keras.Model):
    """Self-attention block: 1x1 convolutions project the input into query,
    key and value feature maps, which are combined by the _Attention layer."""

    def __init__(self, input_dims, data_format='channels_last', **kwargs):
        super(SelfAttnModel, self).__init__(**kwargs)
        self.attn = _Attention(data_format=data_format)
        # Queries and keys are projected to a reduced channel dimension
        # (input_dims // 8); values keep the full channel dimension.
        self.query_conv = tf.keras.layers.Conv2D(filters=input_dims // 8,
                                                 kernel_size=1,
                                                 data_format=data_format)
        self.key_conv = tf.keras.layers.Conv2D(filters=input_dims // 8,
                                               kernel_size=1,
                                               data_format=data_format)
        self.value_conv = tf.keras.layers.Conv2D(filters=input_dims,
                                                 kernel_size=1,
                                                 data_format=data_format)

    def call(self, inputs, training=False):
        q = self.query_conv(inputs)
        k = self.key_conv(inputs)
        v = self.value_conv(inputs)
        # Returns (attention output blended with the input, attention map).
        return self.attn([q, k, v, inputs])


class _Attention(tf.keras.layers.Layer):
    """Spatial self-attention; the output is added back to the input, scaled
    by a learned scalar gamma that starts at zero."""

    def __init__(self, data_format='channels_last', **kwargs):
        super(_Attention, self).__init__(**kwargs)
        self.data_format = data_format

    def build(self, input_shapes):
        # Zero-initialised blending scalar: the layer starts as an identity.
        self.gamma = self.add_weight(name=self.name + '_gamma',
                                     shape=(),
                                     initializer=tf.keras.initializers.Zeros())

    def call(self, inputs):
        if len(inputs) != 4:
            raise ValueError('an attention layer should have 4 inputs')

        query_tensor = inputs[0]
        key_tensor = inputs[1]
        value_tensor = inputs[2]
        origin_input = inputs[3]

        input_shape = tf.shape(query_tensor)

        if self.data_format == 'channels_first':
            height_axis = 2
            width_axis = 3
        else:
            height_axis = 1
            width_axis = 2

        batchsize = input_shape[0]
        height = input_shape[height_axis]
        width = input_shape[width_axis]

        # Flatten the spatial dimensions so attention can be computed over
        # all height*width positions.
        if self.data_format == 'channels_first':
            proj_query = tf.transpose(
                tf.reshape(query_tensor, (batchsize, -1, height * width)), (0, 2, 1))
            proj_key = tf.reshape(key_tensor, (batchsize, -1, height * width))
            proj_value = tf.reshape(value_tensor, (batchsize, -1, height * width))
        else:
            proj_query = tf.reshape(query_tensor, (batchsize, height * width, -1))
            proj_key = tf.transpose(
                tf.reshape(key_tensor, (batchsize, height * width, -1)), (0, 2, 1))
            proj_value = tf.transpose(
                tf.reshape(value_tensor, (batchsize, height * width, -1)), (0, 2, 1))

        # energy[b, i, j]: affinity between query position i and key position j.
        energy = tf.matmul(proj_query, proj_key)
        attention = tf.nn.softmax(energy)

        # Weight the values by the (transposed) attention map.
        out = tf.matmul(proj_value, tf.transpose(attention, (0, 2, 1)))

        # Restore the original spatial layout.
        if self.data_format == 'channels_first':
            out = tf.reshape(out, (batchsize, -1, height, width))
        else:
            out = tf.reshape(
                tf.transpose(out, (0, 2, 1)), (batchsize, height, width, -1))

        # Residual connection scaled by gamma, plus the attention map itself.
        return tf.add(tf.multiply(out, self.gamma), origin_input), attention
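

# --- Usage sketch (illustrative, not from the original repository) ----------
# A minimal example of applying SelfAttnModel to a batch of channels-last
# feature maps. The input size (1, 16, 16, 64) and the variable names below
# are assumptions made for demonstration only.
if __name__ == '__main__':
    feature_maps = tf.random.normal((1, 16, 16, 64))   # NHWC feature maps
    self_attn = SelfAttnModel(input_dims=64)            # 64 input channels
    out, attn_map = self_attn(feature_maps)
    # out keeps the input shape; attn_map has shape (batch, H*W, H*W).
    print(out.shape, attn_map.shape)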