Commit e9761a14 authored by Jacky Lin's avatar Jacky Lin
Browse files

update

parent d4a82aea
package main
import (
"os"
)
func save(net Network) {
h, err := os.Create("data/hweights.model")
defer h.Close()
if err == nil {
net.hiddenWeights.MarshalBinaryTo(h)
}
o, err := os.Create("data/oweights.model")
defer o.Close()
if err == nil {
net.outputWeights.MarshalBinaryTo(o)
}
}
// load a neural network from file
func load(net *Network) {
h, err := os.Open("data/hweights.model")
defer h.Close()
if err == nil {
net.hiddenWeights.Reset()
net.hiddenWeights.UnmarshalBinaryFrom(h)
}
o, err := os.Open("data/oweights.model")
defer o.Close()
if err == nil {
net.outputWeights.Reset()
net.outputWeights.UnmarshalBinaryFrom(o)
}
return
}
module main
go 1.16
require (
golang.org/x/exp v0.0.0-20210220032938-85be41e4509f // indirect
golang.org/x/tools v0.1.0 // indirect
gonum.org/v1/gonum v0.8.2 // indirect
)
dmitri.shuralyov.com/gpu/mtl v0.0.0-20201218220906-28db891af037/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4=
golang.org/x/exp v0.0.0-20210220032938-85be41e4509f h1:GrkO5AtFUU9U/1f5ctbIBXtBGeSJbWwIYfIsTcFMaX4=
golang.org/x/exp v0.0.0-20210220032938-85be41e4509f/go.mod h1:I6l2HNBLBZEcrOoCpyKLdY2lHoRZ8lI4x60KMCQDft4=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE=
golang.org/x/mobile v0.0.0-20201217150744-e6ae53a27f4f/go.mod h1:skQtrUTUwhdJvXM/2KKJzY8pDgNr9I/FOMqDVRPBUS4=
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.1.1-0.20191209134235-331c550502dd/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.1-0.20200828183125-ce943fd02449/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200117012304-6edc0a871e69/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.0.0-20200207183749-b753a1ba74fa/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.0 h1:po9/4sTYwZU9lPhi1tOrb4hCv3qrhiQ77LZfGa2OjwY=
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.8.2 h1:CCXrcPKiGGotvnN6jfUsKk4rRqm7q09/YbKb5xCEvtM=
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
package main
import (
"fmt"
"math"
"gonum.org/v1/gonum/distuv"
"gonum.org/v1/gonum/mat"
)
/* -- Matrix Operation -- */
func dot(m, n mat.Matrix) mat.Matrix {
r, _ := m.Dims()
_, c := n.Dims()
o := mat.NewDense(r, c, nil)
o.Product(m, n)
return o
}
func apply(fn func(i, j int, v float64) float64, m mat.Matrix) mat.Matrix {
r, c := m.Dims()
o := mat.NewDense(r, c, nil)
o.Apply(fn, m)
return o
}
func scale(s float64, m mat.Matrix) mat.Matrix {
r, c := m.Dims()
o := mat.NewDense(r, c, nil)
o.Scale(s, m)
return o
}
func multiply(m, n mat.Matrix) mat.Matrix {
r, c := m.Dims()
o := mat.NewDense(r, c, nil)
o.MulElem(m, n)
return o
}
func add(m, n mat.Matrix) mat.Matrix {
r, c := m.Dims()
o := mat.NewDense(r, c, nil)
o.Add(m, n)
return o
}
func subtract(m, n mat.Matrix) mat.Matrix {
r, c := m.Dims()
o := mat.NewDense(r, c, nil)
o.Sub(m, n)
return o
}
func addScalar(i float64, m mat.Matrix) mat.Matrix {
r, c := m.Dims()
a := make([]float64, r*c)
for x := 0; x < r*c; x++ {
a[x] = i
}
n := mat.NewDense(r, c, a)
return add(m, n)
}
func randomArray(size int, v float64) (data []float64) {
dist := distuv.Uniform{
Min: -1 / math.Sqrt(v),
Max: 1 / math.Sqrt(v),
}
data = make([]float64, size)
for i := 0; i < size; i++ {
data[i] = dist.Rand()
}
return
}
/* -- Sigmoid function -- */
func sigmoid(r, c int, z float64) float64 {
return 1.0 / (1 + math.Exp(-1*z))
}
func sigmoidPrime(m mat.Matrix) mat.Matrix {
rows, _ := m.Dims()
o := make([]float64, rows)
for i := range o {
o[i] = 1
}
ones := mat.NewDense(rows, 1, o)
return multiply(m, subtract(ones, m)) // m * (1 - m)
}
// pretty print a Gonum matrix
func matrixPrint(X mat.Matrix) {
fa := mat.Formatted(X, mat.Prefix(""), mat.Squeeze())
fmt.Printf("%v\n", fa)
}
package main
import (
"gonum.org/v1/gonum/mat"
)
type Network struct {
inputs int
hiddens int
outputs int
hiddenWeights *mat.Dense
outputWeights *mat.Dense
learningRate float64
}
func CreateNetwork(input, hidden, output int, rate float64) (net Network) {
net = Network{
inputs: input,
hiddens: hidden,
outputs: output,
learningRate: rate,
}
net.hiddenWeights = mat.NewDense(net.hiddens, net.inputs, randomArray(net.inputs*net.hiddens, float64(net.inputs)))
net.outputWeights = mat.NewDense(net.outputs, net.hiddens, randomArray(net.hiddens*net.outputs, float64(net.hiddens)))
return
}
func (net Network) FeedForward(inputData []float64) mat.Matrix {
// forward propagation
inputs := mat.NewDense(len(inputData), 1, inputData)
hiddenInputs := dot(net.hiddenWeights, inputs)
hiddenOutputs := apply(sigmoid, hiddenInputs)
finalInputs := dot(net.outputWeights, hiddenOutputs)
finalOutputs := apply(sigmoid, finalInputs)
return finalOutputs
}
// Train the neural network
func (net *Network) Train(inputData []float64, targetData []float64) {
// feedforward
inputs := mat.NewDense(len(inputData), 1, inputData)
hiddenInputs := dot(net.hiddenWeights, inputs)
hiddenOutputs := apply(sigmoid, hiddenInputs)
finalInputs := dot(net.outputWeights, hiddenOutputs)
finalOutputs := apply(sigmoid, finalInputs)
// find errors
targets := mat.NewDense(len(targetData), 1, targetData)
outputErrors := subtract(targets, finalOutputs)
hiddenErrors := dot(net.outputWeights.T(), outputErrors)
// backpropagate
net.outputWeights = add(net.outputWeights,
scale(net.learningRate,
dot(multiply(outputErrors, sigmoidPrime(finalOutputs)),
hiddenOutputs.T()))).(*mat.Dense)
net.hiddenWeights = add(net.hiddenWeights,
scale(net.learningRate,
dot(multiply(hiddenErrors, sigmoidPrime(hiddenOutputs)),
inputs.T()))).(*mat.Dense)
}
package digit_recognition
package main
import (
"fmt"
)
func main() {
s := []int{2, 2, 3}
nw, err := NewNetwork(s)
if err == nil {
fmt.Errorf("fail to create network")
}
fmt.Println(nw.Bias)
}
package digit_recognition
package main
import (
"fmt"
......@@ -7,22 +7,27 @@ import (
)
type Network struct {
numLayer int
size []int
bias []float64
weight []float64
NumLayer int
Sizes []int
Bias []float64
Weight []float64
}
func NewNetwork(size []int) (*Network, error) {
func NewNetwork(sizes []int) (*Network, error) {
if len(size) < 2 {
return nil, fmt.Errorf("must have at least two layer of network")
}
Bias := [][]float64{}
for _, size := range sizes {
}
n := &Network{
numLayer: len(size),
size: size,
bias: randomArray(len(size)),
weight: randomArray(len(size)),
NumLayer: len(sizes),
Sizes: sizes,
Bias: randomArray(len(sizes)),
Weight: randomArray(len(sizes)),
}
return n, nil
......@@ -36,13 +41,10 @@ func randomArray(size int) []float64 {
return arr
}
/*
func (n *Network) feedForward(in []int) {
for _, b := range n.bias {
for _, w := range n.weight {
in = mat.dot
}
}
}
return nil
}*/
// Basic Function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment