Skip to content

Commit

Permalink
fix mnist-test
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Nov 21, 2017
1 parent 9fc72da commit 0ba8f69
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -586,20 +586,22 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
nothing
end

function save_checkpoint(self :: FeedForward, prefix :: AbstractString, state :: OptimizationState)
save_checkpoint(self::FeedForward, prefix::AbstractString, state::OptimizationState) =
save_checkpoint(self.arch, self.arg_params, self.aux_params, prefix, state.curr_epoch)
end
function save_checkpoint(sym :: SymbolicNode, arg_params :: Dict{Base.Symbol, NDArray},
aux_params :: Dict{Base.Symbol, NDArray}, prefix :: AbstractString, epoch :: Int)

function save_checkpoint(sym::SymbolicNode, arg_params::Dict{Symbol},
aux_params::Dict{Symbol}, prefix::AbstractString, epoch::Int)
save("$prefix-symbol.json", sym)
save_dict = merge(Dict{Base.Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params)),
Dict{Base.Symbol, NDArray}(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params)))
save_dict = Dict{Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params))
if !isempty(aux_params)
merge!(save_dict, Dict(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params)))
end
save_filename = format("{1}-{2:04d}.params", prefix, epoch)
save(save_filename, save_dict)
info("Saved checkpoint to '$save_filename'")
end

function load_checkpoint(prefix :: AbstractString, epoch :: Int)
function load_checkpoint(prefix::AbstractString, epoch::Int)
arch = load("$prefix-symbol.json", SymbolicNode)
saved_dict = load(format("{1}-{2:04d}.params", prefix, epoch), NDArray)
arg_params = Dict{Base.Symbol, NDArray}()
Expand Down

0 comments on commit 0ba8f69

Please sign in to comment.