go get github.com/surenderthakran/gomind
GoMind is a neural network library written entirely in Go. It currently supports only a single hidden layer. The network learns from a training set using the back-propagation algorithm.
Some of the salient features of GoMind are:
- Supports the following activation functions:
  - Linear (default)
  - Sigmoid
  - ReLU
  - Leaky ReLU
- Estimates the ideal number of hidden-layer neurons for the given input and output sizes when no count is specified in the model configuration.
- Uses the mean squared error (MSE) function to calculate the error during back-propagation.
Note: For an introduction to how back-propagation works in neural networks, refer to my blog post on the topic.
package main
import (
"github.com/singhsurender/gomind"
)
func main() {
trainingSet := [][][]float64{
[][]float64{[]float64{0, 0}, []float64{0}},
[][]float64{[]float64{0, 1}, []float64{1}},
[][]float64{[]float64{1, 0}, []float64{1}},
[][]float64{[]float64{1, 1}, []float64{0}},
}
mind, err := gomind.New(&gomind.ModelConfiguration{
NumberOfInputs: 2,
NumberOfOutputs: 1,
NumberOfHiddenLayerNeurons: 16,
HiddenLayerActivationFunctionName: "relu",
OutputLayerActivationFunctionName: "sigmoid",
})
if err != nil {
return nil, fmt.Errorf("unable to create neural network. %v", err)
}
for i := 0; i < 500; i++ {
rand := rand.Intn(4)
input := trainingSet[rand][0]
output := trainingSet[rand][1]
if err := mind.LearnSample(input, output); err != nil {
mind.Describe(true)
return nil, fmt.Errorf("error while learning from sample input: %v, target: %v. %v", input, output, err)
}
actual := mind.LastOutput()
outputError, err := mind.CalculateError(output)
if err != nil {
mind.Describe(true)
return nil, fmt.Errorf("error while calculating error for input: %v, target: %v and actual: %v. %v", input, output, actual, err)
}
outputAccuracy, err := mind.CalculateAccuracy(output)
if err != nil {
mind.Describe(true)
return nil, fmt.Errorf("error while calculating error for input: %v, target: %v and actual: %v. %v", input, output, actual, err)
}
fmt.Println("Index: %v, Input: %v, Target: %v, Actual: %v, Error: %v, Accuracy: %v\n", i, input, output, actual, outputError, outputAccuracy)
}
}
The documentation for various methods exposed by the library can be found at: https://godoc.org/github.com/surenderthakran/gomind