FourierKAN

Pytorch Layer for FourierKAN

It is a layer intended as a drop-in substitute for Linear followed by a non-linear activation.

This is inspired by Kolmogorov-Arnold Networks, but uses 1D Fourier coefficients instead of spline coefficients. It should be easier to optimize, since Fourier components are dense (global) whereas splines are local. Once convergence is reached, you can replace each 1D function with a spline approximation for faster evaluation, giving almost the same result. The other advantage of Fourier over splines is that the functions are periodic, and therefore more numerically bounded, avoiding the issues of going out of grid.
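Concretely, each output is a bias plus a sum, over all input coordinates, of learned 1D Fourier series evaluated at that coordinate. A minimal numpy sketch of the function being computed (function and variable names here are hypothetical, not the repo's API):

```python
import numpy as np

def fourier_kan_forward(x, coeffs, bias):
    """Sketch of the FourierKAN function class.

    x:      (batch, inputdim)
    coeffs: (2, outdim, inputdim, gridsize)  -- cos and sin coefficients
    bias:   (outdim,)
    """
    gridsize = coeffs.shape[-1]
    k = np.arange(1, gridsize + 1)           # frequencies 1..gridsize
    kx = x[:, None, :, None] * k             # (batch, 1, inputdim, gridsize)
    # Sum the per-input 1D Fourier series over inputs and frequencies
    y = (np.cos(kx) * coeffs[0]).sum(axis=(-2, -1))
    y += (np.sin(kx) * coeffs[1]).sum(axis=(-2, -1))
    return y + bias

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
coeffs = rng.normal(size=(2, 5, 3, 7))
bias = rng.normal(size=(5,))
print(fourier_kan_forward(x, coeffs, bias).shape)  # (4, 5)
```

The constant (k = 0) term is carried by the bias, which is why frequencies start at 1.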

Usage

Put the file in the same directory as your code, then:

from fftKAN import NaiveFourierKANLayer

Alternatively, you can run python fftKAN.py to see the demo.

The code runs on both CPU and GPU, but is untested.

This is a naive version that uses temporary memory proportional to the gridsize, whereas a fused version wouldn't require any temporary memory.
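To illustrate why that temporary memory is avoidable, here is a numpy sketch (not the repo's code) that accumulates one frequency at a time: the temporaries no longer depend on gridsize, and the result matches the materialized version. A fused kernel takes this idea to its limit.

```python
import numpy as np

def forward_materialized(x, coeffs, bias):
    """Naive version: builds (batch, outdim, inputdim, gridsize) temporaries."""
    k = np.arange(1, coeffs.shape[-1] + 1)
    kx = x[:, None, :, None] * k
    return (np.cos(kx) * coeffs[0] + np.sin(kx) * coeffs[1]).sum(axis=(-2, -1)) + bias

def forward_chunked(x, coeffs, bias):
    """Same result, one frequency at a time: temporaries stay
    (batch, outdim, inputdim) regardless of gridsize."""
    y = np.broadcast_to(bias, (x.shape[0], coeffs.shape[1])).copy()
    for idx in range(coeffs.shape[-1]):
        kx = (idx + 1) * x                                         # (batch, inputdim)
        y += (np.cos(kx)[:, None, :] * coeffs[0, :, :, idx]).sum(-1)
        y += (np.sin(kx)[:, None, :] * coeffs[1, :, :, idx]).sum(-1)
    return y

rng = np.random.default_rng(1)
x = rng.normal(size=(2, 3))
coeffs = rng.normal(size=(2, 4, 3, 6))
bias = rng.normal(size=(4,))
assert np.allclose(forward_materialized(x, coeffs, bias),
                   forward_chunked(x, coeffs, bias))
```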

Training

The higher-frequency terms of the Fourier coefficients may make training difficult, as the function will not be very smooth.

@JeremyIV suggested a Brownian-noise initialisation for the Fourier coefficients (see PR #4); you can try it by constructing the layer with the flag smooth_initialization=True.
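The idea behind a Brownian-noise initialization is that Brownian motion has a 1/f² power spectrum, so Fourier coefficients whose scale decays like 1/k produce smooth initial functions. A hedged sketch of that idea (PR #4's actual implementation may differ):

```python
import numpy as np

def smooth_fourier_init(outdim, inputdim, gridsize, rng):
    """Smoothness-biased initialization sketch: coefficient scale decays
    like 1/k, i.e. a Brownian-like 1/f^2 power spectrum, so the initial
    1D functions are dominated by low frequencies."""
    k = np.arange(1, gridsize + 1)
    return rng.normal(size=(2, outdim, inputdim, gridsize)) / k

rng = np.random.default_rng(0)
coeffs = smooth_fourier_init(8, 4, 32, rng)
# higher frequencies start with much smaller amplitude
assert np.std(coeffs[..., 0]) > np.std(coeffs[..., -1])
```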

One usual way of dealing with higher-frequency Fourier terms is to add a regularization term that penalizes the higher frequencies in whatever way you want. The merit of this approach is that the function is kept smooth as training progresses, not just at initialization.
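Such a penalty could look like the following sketch, where the weight grows with the frequency index k (the k² choice penalizes the energy of the function's derivative; the helper name and the way it would be mixed into a loss are assumptions, not the repo's API):

```python
import numpy as np

def high_freq_penalty(coeffs, power=2.0):
    """Weighted L2 penalty on Fourier coefficients.

    coeffs: (2, outdim, inputdim, gridsize); the weight grows like
    k**power, so high frequencies are penalized more strongly.
    """
    k = np.arange(1, coeffs.shape[-1] + 1)
    return np.sum((k ** power) * coeffs ** 2)
```

In training this would be added to the task loss, e.g. loss = task_loss + lam * high_freq_penalty(layer.fouriercoeffs), with lam controlling the smoothness pressure.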

Highlight of the core:

You can either do the simple thing of materializing the product and then doing the sum, or you can use einsum to do the reduction. Einsum should use less memory, but be slower.

From FourierKAN/fftKAN.py, lines 28 to 58 at commit 9a8c753:

def forward(self, x):
    xshp = x.shape
    outshape = xshp[0:-1] + (self.outdim,)
    x = th.reshape(x, (-1, self.inputdim))
    # Starting at 1 because constant terms are in the bias
    k = th.reshape(th.arange(1, self.gridsize + 1, device=x.device), (1, 1, 1, self.gridsize))
    xrshp = th.reshape(x, (x.shape[0], 1, x.shape[1], 1))
    # This should be fused to avoid materializing memory
    c = th.cos(k * xrshp)
    s = th.sin(k * xrshp)
    # We compute the interpolation of the various functions defined by their
    # Fourier coefficients for each input coordinate, and we sum them
    y = th.sum(c * self.fouriercoeffs[0:1], (-2, -1))
    y += th.sum(s * self.fouriercoeffs[1:2], (-2, -1))
    if self.addbias:
        y += self.bias
    # End fuse
    '''
    # You can use einsum instead to reduce memory usage.
    # It's still not as good as fully fused, but it should help.
    # einsum is usually slower though.
    c = th.reshape(c, (1, x.shape[0], x.shape[1], self.gridsize))
    s = th.reshape(s, (1, x.shape[0], x.shape[1], self.gridsize))
    y2 = th.einsum("dbik,djik->bj", th.concat([c, s], axis=0), self.fouriercoeffs)
    if self.addbias:
        y2 += self.bias
    diff = th.sum((y2 - y) ** 2)
    print("diff")
    print(diff)  # should be ~0
    '''
    y = th.reshape(y, outshape)
    return y
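The commented-out einsum path computes the same reduction as the materialize-then-sum path. A small self-contained numpy check of that equivalence (same subscript string, random data standing in for the cos/sin activations):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, inputdim, outdim, gridsize = 2, 3, 4, 5
c = rng.normal(size=(batch, inputdim, gridsize))   # stands in for cos(k*x)
s = rng.normal(size=(batch, inputdim, gridsize))   # stands in for sin(k*x)
coeffs = rng.normal(size=(2, outdim, inputdim, gridsize))

# Materialize the (batch, outdim, inputdim, gridsize) products, then sum
y_mat = (c[:, None] * coeffs[0]).sum((-2, -1)) + (s[:, None] * coeffs[1]).sum((-2, -1))

# Same reduction via einsum: d indexes the cos/sin channel,
# b batch, j output, i input, k frequency
cs = np.stack([c, s])                               # (2, batch, inputdim, gridsize)
y_ein = np.einsum("dbik,djik->bj", cs, coeffs)

assert np.allclose(y_mat, y_ein)
```

Whether einsum actually saves memory depends on how the backend schedules the contraction; the fully fused kernel remains the only way to avoid the temporaries entirely.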

License

The license is MIT, but future evolutions (including fused kernels) will be proprietary.

Fused Operations

This layer uses a lot of memory, but by fusing operations no extra memory is needed, and we can even use trigonometric tricks.

https://github.com/GistNoesis/FusedFourierKAN
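One such trigonometric trick (a sketch of the general idea, not necessarily what FusedFourierKAN does): the angle-addition identities cos((k+1)x) = cos(kx)cos(x) − sin(kx)sin(x) and sin((k+1)x) = sin(kx)cos(x) + cos(kx)sin(x) let a kernel evaluate all gridsize frequencies from a single cos/sin per input, as a running recurrence with no per-frequency temporaries:

```python
import numpy as np

def fourier_series_recurrent(x, cos_coef, sin_coef):
    """Evaluate sum_k a_k*cos(k*x) + b_k*sin(k*x) with only one explicit
    cos/sin call, advancing k via the angle-addition recurrence."""
    ck, sk = np.cos(x), np.sin(x)          # k = 1 terms
    c1, s1 = ck.copy(), sk.copy()
    y = cos_coef[0] * ck + sin_coef[0] * sk
    for a, b in zip(cos_coef[1:], sin_coef[1:]):
        # advance k -> k+1 using cos(x), sin(x) only
        ck, sk = ck * c1 - sk * s1, sk * c1 + ck * s1
        y += a * ck + b * sk
    return y
```

Besides saving transcendental-function calls, this is exactly the access pattern a fused kernel wants: each frequency's contribution is accumulated immediately instead of being materialized.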
