Skip to content

Commit 31beee1

Browse files
Merge pull request #238 from jeremiedb/pool-bug
fix meanpool bug #229
2 parents 7c8fa27 + e9f1276 commit 31beee1

File tree

3 files changed

+85
-98
lines changed

3 files changed

+85
-98
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
deps/usr
1111
deps.jl
1212
*.log
13+
.vscode
1314
Manifest.toml

src/impl/pooling_direct.jl

+30-30
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Statistics
55
for name in (:max, :mean)
66
@eval function $((Symbol("$(name)pool_direct!")))(
77
y::AbstractArray{T,5}, x::AbstractArray{T,5},
8-
pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T}
8+
pdims::PoolDims; alpha::T=T(1), beta::T=T(0)) where {T}
99
@assert beta == T(0) "beta not supported yet"
1010
check_dims(size(x), size(y), pdims)
1111

@@ -22,17 +22,17 @@ for name in (:max, :mean)
2222
padded_regions, central_region = calc_padding_regions(pdims)
2323

2424
# A helper function to project from output (w, h) to input (input_w, input_h)
25-
@inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
25+
@inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1
2626

2727
# If we're doing mean pooling, we represent division by kernel size by rolling it
2828
# into the `alpha` multiplier.
2929
if $(name == :mean)
30-
alpha = alpha/prod(kernel_size(pdims))
30+
alpha = alpha / prod(kernel_size(pdims))
3131
end
3232

3333
# Each loop, we initialize `m` to something, set that here.
3434
m_init = if $(name == :max)
35-
typemin(T)
35+
T <: AbstractFloat ? nextfloat(typemin(T)) : typemin(T)
3636
elseif $(name == :mean)
3737
T(0)
3838
else
@@ -54,9 +54,9 @@ for name in (:max, :mean)
5454
kh in 1:kernel_h,
5555
kw in 1:kernel_w
5656

57-
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
58-
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
59-
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
57+
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1) * dil_d
58+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1) * dil_h
59+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1) * dil_w
6060

6161
# This conditional will be optimized away at compile time
6262
if $(name == :max)
@@ -67,7 +67,7 @@ for name in (:max, :mean)
6767
error("Unimplemented codegen path")
6868
end
6969
end
70-
y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx]
70+
y[w, h, d, c, batch_idx] = alpha * m + beta * y[w, h, d, c, batch_idx]
7171
end
7272

7373
# Next, the padded regions
@@ -82,23 +82,23 @@ for name in (:max, :mean)
8282
# do so by putting in a bunch of conditionals. :/
8383
m = m_init
8484
for kd in 1:kernel_d
85-
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
85+
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1) * dil_d
8686
if input_kd <= 0 || input_kd > depth
87-
m = max(m, 0.0)
87+
# add here condition for handling options for paded value handling
8888
continue
8989
end
9090

9191
for kh in 1:kernel_h
92-
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
92+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1) * dil_h
9393
if input_kh <= 0 || input_kh > height
94-
m = max(m, 0.0)
94+
# add here condition for handling options for paded value handling
9595
continue
9696
end
9797

9898
for kw in 1:kernel_w
99-
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
99+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1) * dil_w
100100
if input_kw <= 0 || input_kw > width
101-
m = max(m, 0.0)
101+
# add here condition for handling options for paded value handling
102102
continue
103103
end
104104

@@ -112,7 +112,7 @@ for name in (:max, :mean)
112112
end
113113
end
114114
end
115-
y[w, h, d, c, batch_idx] = alpha*m + beta*y[w, h, d, c, batch_idx]
115+
y[w, h, d, c, batch_idx] = alpha * m + beta * y[w, h, d, c, batch_idx]
116116
end
117117
end
118118

@@ -125,7 +125,7 @@ for name in (:max, :mean)
125125
@eval function $((Symbol("$(name)pool_direct!")))(
126126
dx::AbstractArray{T,5}, dy::AbstractArray{T,5},
127127
y::AbstractArray{T,5}, x::AbstractArray{T,5},
128-
pdims::PoolDims; alpha::T = T(1), beta::T = T(0)) where {T}
128+
pdims::PoolDims; alpha::T=T(1), beta::T=T(0)) where {T}
129129
check_dims(size(x), size(dy), pdims)
130130

131131
width, height, depth = input_size(pdims)
@@ -141,12 +141,12 @@ for name in (:max, :mean)
141141
padded_regions, central_region = calc_padding_regions(pdims)
142142

143143
# A helper function to project from output (w, h) to input (input_w, input_h)
144-
@inline project(idx, stride, pad) = (idx - 1)*stride - pad + 1
144+
@inline project(idx, stride, pad) = (idx - 1) * stride - pad + 1
145145

146146
# If we're doing mean pooling, we represent division by kernel size by rolling
147147
# it into the `alpha` multiplier.
148148
if $(name == :mean)
149-
alpha = alpha/prod(kernel_size(pdims))
149+
alpha = alpha / prod(kernel_size(pdims))
150150
end
151151

152152
# Start with the central region
@@ -166,9 +166,9 @@ for name in (:max, :mean)
166166
kh in 1:kernel_h,
167167
kw in 1:kernel_w
168168

169-
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
170-
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
171-
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
169+
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1) * dil_d
170+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1) * dil_h
171+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1) * dil_w
172172

173173
# This conditional will be optimized away at compile time,
174174
# or my name isn't shengdan jingyu
@@ -179,15 +179,15 @@ for name in (:max, :mean)
179179
# Uncomment line below if using with non-precise output (e.g. by NNPACK)
180180
# if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose
181181
if y_idx x[x_idxs...] && !maxpool_already_chose
182-
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
182+
dx[x_idxs...] += dy_idx * alpha + beta * dx[x_idxs...]
183183
maxpool_already_chose = true
184184
# Maxpooling does not support `beta` right now. :(
185-
#else
185+
# else
186186
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
187187
end
188188
elseif $(name == :mean)
189189
# Either does meanpool :(
190-
dx[x_idxs...] = dy_idx*alpha + dx[x_idxs...]
190+
dx[x_idxs...] = dy_idx * alpha + dx[x_idxs...]
191191
else
192192
error("Unimplemented codegen path")
193193
end
@@ -210,19 +210,19 @@ for name in (:max, :mean)
210210
# In these loops, we have to check that we're not reaching off the edge,
211211
# we do so by putting in a bunch of conditionals. :/
212212
for kd in 1:kernel_d
213-
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1)*dil_d
213+
input_kd = project(d, stride_d, pad_d_lo) + (kd - 1) * dil_d
214214
if input_kd <= 0 || input_kd > depth
215215
continue
216216
end
217217

218218
for kh in 1:kernel_h
219-
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1)*dil_h
219+
input_kh = project(h, stride_h, pad_h_lo) + (kh - 1) * dil_h
220220
if input_kh <= 0 || input_kh > height
221221
continue
222222
end
223223

224224
for kw in 1:kernel_w
225-
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1)*dil_w
225+
input_kw = project(w, stride_w, pad_w_lo) + (kw - 1) * dil_w
226226
if input_kw <= 0 || input_kw > width
227227
continue
228228
end
@@ -233,13 +233,13 @@ for name in (:max, :mean)
233233
# Uncomment line below if using with non-precise output
234234
# if abs(y_idx - x[x_idxs...]) < 1e-5 && !maxpool_already_chose
235235
if y_idx x[x_idxs...] && !maxpool_already_chose
236-
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
236+
dx[x_idxs...] += dy_idx * alpha + beta * dx[x_idxs...]
237237
maxpool_already_chose = true
238-
#else
238+
# else
239239
# dx[x_idxs...] = T(0) + beta*dx[x_idxs...]
240240
end
241241
elseif $(name == :mean)
242-
dx[x_idxs...] += dy_idx*alpha + beta*dx[x_idxs...]
242+
dx[x_idxs...] += dy_idx * alpha + beta * dx[x_idxs...]
243243
else
244244
error("Unimplemented codegen path")
245245
end

0 commit comments

Comments
 (0)