Update for 0.4.6
SkyWorld117 committed Jun 6, 2021
1 parent df61ea4 commit 3caebae
Showing 5 changed files with 40 additions and 43 deletions.
24 changes: 9 additions & 15 deletions Manifest.toml
@@ -11,9 +11,9 @@ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"

[[ArrayInterface]]
deps = ["IfElse", "LinearAlgebra", "Requires", "SparseArrays", "Static"]
git-tree-sha1 = "c1878c1fa6342a703c791acf6916c2cd3a3aeab6"
git-tree-sha1 = "045ff5e1bc8c6fb1ecb28694abba0a0d55b5f4f5"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.16"
version = "3.1.17"

[[Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
@@ -33,12 +33,6 @@ git-tree-sha1 = "e747dac84f39c62aff6956651ec359686490134e"
uuid = "0b7ba130-8d10-5ba8-a3d6-c5182647fed9"
version = "1.21.0+0"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "5d64be50ea9b43a89b476be773e125cef03c7cd5"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.1"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
@@ -133,10 +127,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[LoopVectorization]]
deps = ["ArrayInterface", "ChainRulesCore", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Polyester", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "a52a8eac5f1ddae7fb7e0ac4ef71f91f346648dc"
deps = ["ArrayInterface", "DocStringExtensions", "IfElse", "LinearAlgebra", "OffsetArrays", "Polyester", "Requires", "SLEEFPirates", "Static", "StrideArraysCore", "ThreadingUtilities", "UnPack", "VectorizationBase"]
git-tree-sha1 = "7d4237e46e44871d7ad39fab4e38d1e98df3f5c8"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.12.30"
version = "0.12.34"

[[Lz4_jll]]
deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
@@ -242,9 +236,9 @@ uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StrideArraysCore]]
deps = ["ArrayInterface", "Requires", "ThreadingUtilities", "VectorizationBase"]
git-tree-sha1 = "42491616950994149c6abfa960340745fae309d1"
git-tree-sha1 = "efcdfcbb8cf91e859f61011de1621be34b550e69"
uuid = "7792a7ef-975c-4747-a70f-980b88e8d1da"
version = "0.1.11"
version = "0.1.13"

[[TOML]]
deps = ["Dates"]
@@ -278,9 +272,9 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[VectorizationBase]]
deps = ["ArrayInterface", "Hwloc", "IfElse", "Libdl", "LinearAlgebra", "Static"]
git-tree-sha1 = "aac79bdce31bdca6374146d1df3897b057281020"
git-tree-sha1 = "7c8974c7de377a2dc67e778017c78f96fc8f0fc6"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.20.13"
version = "0.20.16"

[[VectorizedRNG]]
deps = ["Distributed", "Random", "UnPack", "VectorizationBase"]
8 changes: 4 additions & 4 deletions Project.toml
@@ -1,7 +1,7 @@
name = "DianoiaML"
uuid = "0bf644f3-9992-48f5-8d6e-b7168dcf6b06"
authors = ["Yi Zhu, @SkyWord117"]
version = "0.4.5"
version = "0.4.6"

[deps]
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
@@ -11,8 +11,8 @@ Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e"

[compat]
julia = "1.6.1"
HDF5 = "0.15.5"
LoopVectorization = "0.12.30"
LoopVectorization = "0.12.34"
Polyester = "0.3.1"
VectorizedRNG = "0.2.11"
VectorizedRNG = "0.2.11"
julia = "1.6.1"
4 changes: 4 additions & 0 deletions UPDATES.md
@@ -1,5 +1,9 @@
# Update Log

### Update 0.4.6 - 06.06.2021
- Fixed **Minibatch_GD** and **GA**.
- Slightly Optimized **GA**.

### Update 0.4.5 - 06.05.2021
- Added **ResNet** as a network type.
- Added **Residual** as a layer to load **ResNet**.
42 changes: 20 additions & 22 deletions src/optimizer/ga.jl
@@ -1,9 +1,12 @@
module GA
using Polyester, LoopVectorization
using Polyester, LoopVectorization, VectorizedRNG

function fit(;models::Array, input_data::Array{Float32}, output_data::Array{Float32}, loss_function::Any, monitor::Any, α::Float64=0.01, num_copy::Int64, epochs::Int64=20, batch::Real=32, mini_batch::Int64=5)
function fit(;models::Array, input_data::Array{Float32}, output_data::Array{Float32}, monitor::Any, α::Float64=0.01, num_copy::Int64, epochs::Int64=20, batch::Real=32, mini_batch::Int64=5)
gene_pool = length(models)
batch_size = ceil(Int64, size(input_data)[end]/batch)*mini_batch
input_shape = models[1].layers[2].input_shape
output_shape = models[1].layers[end].output_shape
examples = size(input_data)[end]
batch_size = ceil(Int64, examples/batch)*mini_batch
current_input_data = zeros(Float32, input_shape..., mini_batch)
current_output_data = zeros(Float32, output_shape..., mini_batch)

@@ -15,11 +18,12 @@ module GA
print("Epoch ", e, "\n[")
loss = 0
losses = zeros(Float32, gene_pool)
weights = zeros(Float32, gene_pool)
@time begin

for t in 1:mini_batch:batch_size-mini_batch+1
@batch for i in 1:mini_batch
index = rand(1:size(input_data)[end])
index = rand(1:examples)
selectdim(current_input_data, length(input_shape)+1, i) .= selectdim(input_data, length(input_shape)+1, index)
selectdim(current_output_data, length(output_shape)+1, i) .= selectdim(output_data, length(output_shape)+1, index)
end
@@ -28,14 +32,14 @@ module GA
print("=")
end

for i in 1:gene_pool
@time for i in 1:gene_pool
models[i].activate(models[i], current_input_data)
losses[i] = monitor.func(models[i].layers[end-1].output, current_output_data)
end

for i in 1:gene_pool-num_copy
#recomutation!(models[argmax(losses)], models[rand(1:gene_pool)], models[rand(1:gene_pool)], α, t, batch_size-mini_batch+1)
recomutation!(models[argmax(losses)], models[sample(losses)], models[sample(losses)], α, t, batch_size-mini_batch+1)
@time for i in 1:gene_pool-num_copy
get_weights!(weights, losses)
recomutation!(models[argmax(losses)], models[sample(weights)], models[sample(weights)], α, t, batch_size-mini_batch+1)
losses[argmax(losses)] = Inf32
end

@@ -46,23 +50,17 @@ module GA
end
end

function sample(losses)
weights = Array{Float32, 1}(undef, length(losses))
s = 0.0f0
for i in 1:length(losses)
if losses[i] != Inf32
s += losses[i]
end
end
for i in 1:length(losses)
weights[i] = s/losses[i]
function get_weights!(weights::Array{Float32}, losses::Array{Float32})
@avxt for i in eachindex(losses)
weights[i] = ifelse(losses[i]!=Inf32, 1/losses[i], 0.0f0)
end
s = sum(weights)
@avx for i in 1:length(losses)
weights[i] /= s
end
@avxt weights ./= s
end

function sample(weights)
r = rand()
for i in 1:length(losses)
for i in 1:length(weights)
if weights[i]>=r
return i
else
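To make the new selection path easier to follow outside the diff, here is a minimal plain-Julia sketch of what `get_weights!` and `sample` do together: each model's weight is proportional to 1/loss, elites already marked with `Inf32` get zero weight, and a roulette-wheel pass picks a parent index. The `@avxt` loops from LoopVectorization are replaced by ordinary loops, and the tail of `sample` (cut off in the diff above) is reconstructed as the standard cumulative-subtraction form, so treat this as an illustration rather than the repository's exact code.

```julia
# Sketch of fitness-weighted parent selection (plain Julia, no
# LoopVectorization): weights ∝ 1/loss, with Inf32 losses (already-copied
# elites) weighted zero, then normalized to sum to 1.
function get_weights_sketch!(weights::Vector{Float32}, losses::Vector{Float32})
    for i in eachindex(losses)
        weights[i] = losses[i] != Inf32 ? 1.0f0 / losses[i] : 0.0f0
    end
    weights ./= sum(weights)   # assumes at least one finite loss
    return weights
end

# Roulette-wheel sampling over normalized weights; the loop body is a
# reconstruction of the cut-off part of `sample`, not verbatim repo code.
function sample_sketch(weights::Vector{Float32})
    r = rand()
    for i in eachindex(weights)
        weights[i] >= r && return i
        r -= weights[i]
    end
    return lastindex(weights)   # guard against floating-point round-off
end

# Usage: pick two parents with probability inversely proportional to loss.
losses  = Float32[0.9, 0.3, Inf32, 0.6]
weights = similar(losses)
get_weights_sketch!(weights, losses)
parent_a = sample_sketch(weights)
parent_b = sample_sketch(weights)
```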
5 changes: 3 additions & 2 deletions src/optimizer/minibatch_gd.jl
@@ -5,7 +5,8 @@ module Minibatch_GD
model.initialize(model, mini_batch)
input_shape = size(input_data)[1:end-1]
output_shape = size(output_data)[1:end-1]
batch_size = ceil(Int64, size(input_data)[end]/batch)*mini_batch
examples = size(input_data)[end]
batch_size = ceil(Int64, examples/batch)*mini_batch

current_input_data = zeros(Float32, input_shape..., mini_batch)
current_output_data = zeros(Float32, output_shape..., mini_batch)
@@ -15,7 +16,7 @@ module Minibatch_GD
@time begin
for t in 1:mini_batch:batch_size-mini_batch+1
@batch for i in 1:mini_batch
index = rand(1:size(input_data)[end])
index = rand(1:examples)
selectdim(current_input_data, length(input_shape)+1, i) .= selectdim(input_data, length(input_shape)+1, index)
selectdim(current_output_data, length(output_shape)+1, i) .= selectdim(output_data, length(output_shape)+1, index)
end
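Both optimizers assemble their minibatches the same way; the 0.4.6 change simply hoists `size(input_data)[end]` into `examples` so it is computed once. A self-contained sketch of that sampling step follows, assuming small 2-D data (features × examples) and omitting Polyester's `@batch`; the shapes are made up for illustration.

```julia
# Standalone sketch of the random minibatch assembly: draw `mini_batch`
# random examples and copy them into preallocated buffers via selectdim.
input_data  = rand(Float32, 4, 100)   # assumed: 4 features × 100 examples
output_data = rand(Float32, 2, 100)   # assumed: 2 targets  × 100 examples
mini_batch  = 5

examples     = size(input_data)[end]          # computed once, as in the fix
input_shape  = size(input_data)[1:end-1]
output_shape = size(output_data)[1:end-1]

current_input_data  = zeros(Float32, input_shape..., mini_batch)
current_output_data = zeros(Float32, output_shape..., mini_batch)

for i in 1:mini_batch
    index = rand(1:examples)
    selectdim(current_input_data,  length(input_shape)+1,  i) .= selectdim(input_data,  length(input_shape)+1,  index)
    selectdim(current_output_data, length(output_shape)+1, i) .= selectdim(output_data, length(output_shape)+1, index)
end
```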
