diff --git a/src/model.jl b/src/model.jl index b0004d800..51d7bd5da 100644 --- a/src/model.jl +++ b/src/model.jl @@ -586,20 +586,22 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra nothing end -function save_checkpoint(self :: FeedForward, prefix :: AbstractString, state :: OptimizationState) +save_checkpoint(self::FeedForward, prefix::AbstractString, state::OptimizationState) = save_checkpoint(self.arch, self.arg_params, self.aux_params, prefix, state.curr_epoch) -end -function save_checkpoint(sym :: SymbolicNode, arg_params :: Dict{Base.Symbol, NDArray}, - aux_params :: Dict{Base.Symbol, NDArray}, prefix :: AbstractString, epoch :: Int) + +function save_checkpoint(sym::SymbolicNode, arg_params::Dict{Symbol}, + aux_params::Dict{Symbol}, prefix::AbstractString, epoch::Int) save("$prefix-symbol.json", sym) - save_dict = merge(Dict{Base.Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params)), - Dict{Base.Symbol, NDArray}(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params))) + save_dict = Dict{Symbol, NDArray}(map((x) -> Symbol("arg:$(x[1])") => x[2], arg_params)) + if !isempty(aux_params) + merge!(save_dict, Dict(map((x) -> Symbol("aux:$(x[1])") => x[2], aux_params))) + end save_filename = format("{1}-{2:04d}.params", prefix, epoch) save(save_filename, save_dict) info("Saved checkpoint to '$save_filename'") end -function load_checkpoint(prefix :: AbstractString, epoch :: Int) +function load_checkpoint(prefix::AbstractString, epoch::Int) arch = load("$prefix-symbol.json", SymbolicNode) saved_dict = load(format("{1}-{2:04d}.params", prefix, epoch), NDArray) arg_params = Dict{Base.Symbol, NDArray}()