Skip to content

Commit 0c3d873

Browse files
committed
API tweaks and added docstrings for utility layers
1 parent 9ec00dc commit 0c3d873

File tree

2 files changed

+25
-28
lines changed

2 files changed

+25
-28
lines changed

src/layers.jl

+22-17
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ function mlpblock(planes, hidden_planes, dropout = 0., dense = Dense; activation
107107
dense(hidden_planes, planes, activation), Dropout(dropout))
108108
end
109109

110-
# Patching layer used by many vision transformer-like models
110+
"""
111+
Patching{T <: Integer}
112+
Patching layer used by many vision transformer-like models to split the input image into patches.
113+
Can be instantiated with a tuple `(patch_height, patch_width)` or a single value `patch_size`.
114+
"""
111115
struct Patching{T <: Integer}
112116
patch_height::T
113117
patch_width::T
@@ -125,32 +129,33 @@ end
125129

126130
@functor Patching
127131

128-
# Positional embedding layer used by many vision transformer-like models
129-
struct PosEmbedding
130-
embedding_vector
132+
"""
133+
PosEmbedding{T}
134+
135+
Positional embedding layer used by many vision transformer-like models. Instantiated with an
136+
embedding vector which is a learnable parameter.
137+
"""
138+
struct PosEmbedding{T}
139+
embedding_vector::T
131140
end
132141

133142
(p::PosEmbedding)(x) = x .+ p.embedding_vector[:, 1:size(x)[2], :]
134143

135144
@functor PosEmbedding
136145

137-
# Class tokens used by many vision transformer-like models
138-
struct CLSTokens
139-
cls_token
146+
"""
147+
CLSTokens{T}
148+
149+
Appends class tokens to the input that are used for classification by many vision
150+
transformer-like models. Instantiated with a class token vector which is a learnable parameter.
151+
"""
152+
struct CLSTokens{T}
153+
cls_token::T
140154
end
141155

142156
function(m::CLSTokens)(x)
143157
cls_tokens = repeat(m.cls_token, 1, 1, size(x)[3])
144-
x = cat(cls_tokens, x; dims = 2)
158+
return cat(cls_tokens, x; dims = 2)
145159
end
146160

147161
@functor CLSTokens
148-
149-
# Utility function to decide if mean pooling happens inside the model
150-
struct CLSPooling
151-
mode
152-
end
153-
154-
(m::CLSPooling)(x) = (m.mode == "cls") ? x[:, 1, :] : _seconddimmean(x)
155-
156-
@functor CLSPooling

src/vit-based/vit.jl

+3-11
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@ end
4646

4747
@functor MHAttention
4848

49-
struct Transformer
50-
layers
51-
end
52-
5349
"""
5450
Transformer(planes, depth, heads, headplanes, mlppanes, dropout = 0.)
5551
@@ -69,13 +65,9 @@ function Transformer(planes, depth, heads, headplanes, mlpplanes, dropout = 0.)
6965
SkipConnection(prenorm(planes, mlpblock(planes, mlpplanes, dropout)), +))
7066
for _ in 1:depth]
7167

72-
Transformer(Chain(layers...))
68+
Chain(layers...)
7369
end
7470

75-
(m::Transformer)(x) = m.layers(x)
76-
77-
@functor Transformer
78-
7971
"""
8072
vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 16), planes = 1024,
8173
depth = 6, heads = 16, mlppanes = 2048, headplanes = 64, dropout = 0.1, emb_dropout = 0.1,
@@ -120,7 +112,7 @@ function vit(imsize::NTuple{2} = (256, 256); inchannels = 3, patch_size = (16, 1
120112
PosEmbedding(rand(Float32, (planes, num_patches + 1, 1))),
121113
Dropout(emb_dropout),
122114
Transformer(planes, depth, heads, headplanes, mlppanes, dropout),
123-
CLSPooling(pool),
115+
(pool == "cls") ? x -> x[:, 1, :] : x -> _seconddimmean(x),
124116
Chain(LayerNorm(planes), Dense(planes, nclasses)))
125117
end
126118

@@ -164,6 +156,6 @@ end
164156
(m::ViT)(x) = m.layers(x)
165157

166158
backbone(m::ViT) = m.layers[1:end-1]
167-
classifier(m::MLPMixer) = m.layers[end]
159+
classifier(m::ViT) = m.layers[end]
168160

169161
@functor ViT

0 commit comments

Comments (0)