Skip to content

Commit

Permalink
Merge pull request apache#86 from oist/vc/transfer_model
Browse files Browse the repository at this point in the history
Pre-seed model from smaller model.
  • Loading branch information
pluskid committed Apr 25, 2016
2 parents bbe3151 + 4c7a066 commit 539a070
Showing 1 changed file with 40 additions and 13 deletions.
53 changes: 40 additions & 13 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,32 +92,59 @@ function init_model(self :: FeedForward, initializer :: AbstractInitializer; ove
param_names = setdiff(arg_names, input_names)
aux_names = list_auxiliary_states(self.arch)

arg_defined = true
aux_defined = true

arg_shapes, out_shapes, aux_shapes = infer_shape(self.arch; input_shapes...)

# If target dict is not yet defined set a temporary one
if !isdefined(self, :arg_params)
param_name_shapes = filter(x -> in(x[1],param_names), zip(arg_names, arg_shapes))
self.arg_params = Dict([name => empty(shape) for (name,shape) in param_name_shapes])
arg_defined = false
self.arg_params = Dict{Symbol, NDArray}()
end
if !isdefined(self, :aux_params)
self.aux_params = Dict([name => empty(shape) for (name,shape) in zip(aux_names,aux_shapes)])
aux_defined = false
self.aux_params = Dict{Symbol, NDArray}()
end

arg_params = Dict{Symbol, NDArray}()
aux_params = Dict{Symbol, NDArray}()

for (name, shape) in filter(x -> in(x[1],param_names), zip(arg_names, arg_shapes))
if haskey(self.arg_params, name)
if shape == size(self.arg_params[name])
arg_params[name] = self.arg_params[name]
continue
else
warn("Shape mismatch for $name. Overwriting with new one.")
delete!(self.arg_params, name)
end
end
arg_params[name] = empty(shape)
end

# initialize the contents of the parameters
if !arg_defined || overwrite
for (k,v) in self.arg_params
for (name, shape) in zip(aux_names, aux_shapes)
if haskey(self.aux_params, name)
if shape == size(self.auxg_params[name])
aux_params[name] = self.aux_params[name]
continue
else
warn("Shape mismatch for $name. Overwriting with new one.")
delete!(self.aux_params, name)
end
end
aux_params[name] = empty(shape)
end

for (k,v) in arg_params
if overwrite || !haskey(self.arg_params, k)
init(initializer, k, v)
end
end
if !aux_defined || overwrite
for (k,v) in self.aux_params
for (k,v) in aux_params
if overwrite || !haskey(self.aux_params, k)
init(initializer, k, v)
end
end

self.arg_params = arg_params
self.aux_params = aux_params

return (arg_names, param_names, aux_names)
end

Expand Down

0 comments on commit 539a070

Please sign in to comment.