TODO
- do some proper performance testing on schaap for different b, ic, oc, i0/i1, f0/f1 combinations, make some pretty graphs
- add support for full convolutions (pad the input as well, plus extra padding for the filters)
  * I think this is necessary for the gradient
- is it possible to support strided convolutions?
- build a ConvOp optimisation that replaces the ConvOp with the FFT-product-IFFT sequence.
- look for other optimisations for BatchedComplexDot and CuFFT
- previous attempt by Josh Bleecher Snyder using CuFFT directly: https://github.com/lumberlabs/fft_conv_op/blob/master/fft_conv_op.py
  * we can reuse the optimisation code from this as a basis.
- FFTs can probably be done 'in place' as well, this could be faster. To make this feasible in Theano, we'd need an FFTOp and an InPlaceFFTOp, where the latter gets swapped in by an optimisation whenever this is possible.
  * find out if scikits.cuda.fft.fft is destructive by default!
- we can just write an optimisation that swaps the ConvOp for FFTGpuConv, just as the ConvOp is currently swapped for GpuConv. FFTGpuConv would in turn be a macro for GpuFFT + product + GpuIFFT, so no gradients are needed anywhere!
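
For the full-convolution item above, a minimal CPU sketch of the padding idea may help. It assumes that "pad the input as well" means zero-padding both operands to (i0 + f0 - 1, i1 + f1 - 1), which is the usual way to turn a circular convolution into a full linear one. The function name fft_conv_full_2d is hypothetical, and NumPy's FFT stands in for CuFFT here.

```python
import numpy as np

def fft_conv_full_2d(image, kernel):
    """Full 2D convolution of one image with one kernel via FFT (sketch)."""
    i0, i1 = image.shape
    f0, f1 = kernel.shape
    # Pad both operands to the full output size so nothing wraps around.
    s = (i0 + f0 - 1, i1 + f1 - 1)
    # rfft2 zero-pads its input to shape `s` before transforming.
    out_f = np.fft.rfft2(image, s=s) * np.fft.rfft2(kernel, s=s)
    # Pass `s` again so an odd last dimension is recovered correctly.
    return np.fft.irfft2(out_f, s=s)
```

Because nothing wraps around, no slicing is needed afterwards; the valid convolution is the sub-block full[f0 - 1:i0, f1 - 1:i1] of this result.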
process
input: (b, ic, i0, i1) float32
filters: (oc, ic, f0, f1) float32
* pad filters
filters: (oc, ic, i0, i1) float32 # same size as input
* reshape for FFT
input: (b * ic, i0, i1) float32
filters: (oc * ic, i0, i1) float32
* perform FFT
input: (b * ic, i0, i1//2 + 1) complex64
filters: (oc * ic, i0, i1//2 + 1) complex64
* reshape again to unfold ic dimension, and to separate b and oc
input: (b, 1, ic, i0, i1//2 + 1) complex64
filters: (1, oc, ic, i0, i1//2 + 1) complex64
* elementwise product
output: (b, oc, ic, i0, i1//2 + 1) complex64
* sum over the ic dimension (we can do this in the Fourier domain because the FFT is linear)
output: (b, oc, i0, i1//2 + 1) complex64
* reshape for IFFT
output: (b * oc, i0, i1//2 + 1) complex64
* perform IFFT
output: (b * oc, i0, i1) float32
* slice to keep only the valid region, because the convolution was circular
output: (b * oc, i0 - f0 + 1, i1 - f1 + 1) float32
* reshape
output: (b, oc, i0 - f0 + 1, i1 - f1 + 1) float32
=> result is a batched valid convolution of the input with the filters.
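
The steps above can be sketched on the CPU with NumPy to check the shapes and the final slice. This is only an illustration, not the GPU implementation: the function name fft_conv_valid is hypothetical, np.fft.rfft2/irfft2 stand in for CuFFT, and the sketch assumes the filters are zero-padded at the end of each axis, so the valid region starts at index (f0 - 1, f1 - 1). Depending on the NumPy version, the transforms may run in double precision rather than complex64.

```python
import numpy as np

def fft_conv_valid(inputs, filters):
    """Batched valid convolution via FFT, following the process above (sketch)."""
    b, ic, i0, i1 = inputs.shape
    oc, ic_f, f0, f1 = filters.shape
    assert ic == ic_f, "input and filter channel counts must match"

    # pad filters to the input size (zeros appended at the end: an assumption)
    padded = np.zeros((oc, ic, i0, i1), dtype=inputs.dtype)
    padded[:, :, :f0, :f1] = filters

    # reshape for FFT, then transform over the last two axes
    inp_f = np.fft.rfft2(inputs.reshape(b * ic, i0, i1))
    fil_f = np.fft.rfft2(padded.reshape(oc * ic, i0, i1))

    # reshape to unfold ic and to separate b and oc for broadcasting
    inp_f = inp_f.reshape(b, 1, ic, i0, i1 // 2 + 1)
    fil_f = fil_f.reshape(1, oc, ic, i0, i1 // 2 + 1)

    # elementwise product, then sum over the ic dimension
    out_f = (inp_f * fil_f).sum(axis=2)          # (b, oc, i0, i1//2 + 1)

    # reshape for IFFT; pass s so an odd i1 is recovered correctly
    out = np.fft.irfft2(out_f.reshape(b * oc, i0, i1 // 2 + 1), s=(i0, i1))

    # slice: the first f0 - 1 rows and f1 - 1 columns contain wrapped values
    out = out[:, f0 - 1:, f1 - 1:]
    return out.reshape(b, oc, i0 - f0 + 1, i1 - f1 + 1)
```

Note that the (b, oc, ic, ...) intermediate product is materialised in full here, which is the part a BatchedComplexDot would avoid.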