Merge pull request apache#87 from oist/vc/finetuning
[RFC] Freeze layers
pluskid committed Apr 26, 2016
2 parents 539a070 + 68180f7 commit 5ba27eb
Showing 4 changed files with 70 additions and 12 deletions.
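
This RFC drives freezing through a symbol attribute: fit() scans list_attr(self.arch) for keys ending in "grad" whose value is "freeze", requests GRAD_NOP for the matching parameters, and skips them in the update loop. A minimal sketch of the intended workflow follows; the network, layer names, and the exact key under which list_attr reports the attribute are illustrative assumptions, not part of this diff.

    # Sketch only: assumes the MXNet.jl FeedForward API of this era.
    using MXNet

    data = mx.Variable(:data)
    # Marking the weight Variable with a :grad => "freeze" attribute is the
    # intended way to keep fc1's weight fixed during fine-tuning.
    w1   = mx.Variable(:fc1_weight, attrs = Dict(:grad => "freeze"))
    fc1  = mx.FullyConnected(data = data, weight = w1, name = :fc1, num_hidden = 128)
    act1 = mx.Activation(data = fc1, act_type = :relu)
    fc2  = mx.FullyConnected(data = act1, name = :fc2, num_hidden = 10)
    net  = mx.SoftmaxOutput(data = fc2, name = :softmax)

    model = mx.FeedForward(net, context = mx.cpu())
    # Toy data provider; with the freeze attribute in place, fc1_weight keeps its
    # initial (e.g. pre-trained) values while the rest of the network trains.
    train = mx.ArrayDataProvider(:data => rand(100, 64), :softmax_label => rand(0:9, 64),
                                 batch_size = 16)
    mx.fit(model, mx.SGD(lr = 0.01), train, n_epoch = 2)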
29 changes: 19 additions & 10 deletions src/executor.jl
@@ -128,22 +128,31 @@ function bind(self :: SymbolicNode; kwargs...)
bind(self, context, args; kwargs...)
end

function simple_bind(self :: SymbolicNode, ctx :: Context; grad_req :: GRAD_REQ=GRAD_WRITE, kwargs...)
function simple_bind(self :: SymbolicNode, ctx :: Context;
grad_req :: Union{GRAD_REQ, Dict{Symbol, GRAD_REQ}}=GRAD_WRITE,
kwargs...)
arg_shapes, out_shapes, aux_shapes = infer_shape(self; kwargs...)
@assert(!isa(arg_shapes, Void), "Information not enough to perform complete shape inference")

arg_arrays = NDArray[zeros(shape, ctx) for shape in arg_shapes]
arg_names = list_arguments(self)
if grad_req == GRAD_NOP
grad_arrays = Dict{Base.Symbol,NDArray}()
else

grad_arrays = Dict{Symbol,NDArray}()

if grad_req != GRAD_NOP
shapes = zip(arg_names, arg_shapes)

# if not in provided data, should be parameters
provided_data_names = [x[1] for x in kwargs]
grad_arrays = Dict{Base.Symbol,NDArray}()
for (name, shape) in zip(arg_names, arg_shapes)
# if not in provided data, should be parameters
if !in(name, provided_data_names)
grad_arrays[name] = zeros(shape, ctx)
end
shapes = filter(x -> !in(x[1], provided_data_names), shapes)

# Remove all gradients for nop params
# if isa(grad_req, Dict{Symbol, GRAD_REQ})
# shapes = filter(x -> grad_req[x[1]] != GRAD_NOP,shapes)
# end

for (name, shape) in shapes
grad_arrays[name] = zeros(shape, ctx)
end
end

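With the new signature above, simple_bind's grad_req can be either a single GRAD_REQ or a per-argument Dict. A small sketch (argument names and shapes are illustrative assumptions): note that in this revision the per-argument NOP filter is still commented out, so gradient arrays are allocated for every non-data argument and the actual freezing happens later in fit()'s update loop.

    # Sketch only: MXNet.jl shapes are column-major, here (features, batch_size).
    net  = mx.FullyConnected(data = mx.Variable(:data), name = :fc, num_hidden = 10)
    req  = Dict(:fc_weight => mx.GRAD_NOP, :fc_bias => mx.GRAD_WRITE)
    exec = mx.simple_bind(net, mx.cpu(); grad_req = req, data = (4, 100))
    # Gradient buffers are created for :fc_weight and :fc_bias alike; GRAD_NOP
    # entries only take effect in fit()'s parameter-update loop (model.jl below).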
27 changes: 26 additions & 1 deletion src/model.jl
@@ -369,11 +369,32 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
kvstore, update_on_kvstore = _create_kvstore(kvstore, length(self.ctx), self.arg_params)
end

# get grad attribute to allow for freezing
freeze_names = Symbol[]
for (attr, value) in list_attr(self.arch)
sattr = string(attr)
if endswith(sattr, "grad") && value == "freeze"
push!(freeze_names, symbol(sattr[1:end-5]))
end
end
# Indices must match the parameter order used in the update loop below (idx = 1:length(param_names)).
freeze_idx = filter(i -> in(param_names[i], freeze_names), 1:length(param_names))

# Setup grad_req as a dictionary
grad_req = Dict{Symbol, GRAD_REQ}()
for param in param_names
if in(param, freeze_names)
grad_req[param] = GRAD_NOP
else
grad_req[param] = GRAD_WRITE
end
end

train_execs = Array(Executor, num_dev)
for i = 1:num_dev
data_shapes = [k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in provide_data(data)]
label_shapes = [k => tuple(v[1:end-1]...,length(slices[i])) for (k,v) in provide_label(data)]
train_execs[i] = simple_bind(self.arch, self.ctx[i]; grad_req=GRAD_WRITE, data_shapes..., label_shapes...)
train_execs[i] = simple_bind(self.arch, self.ctx[i]; grad_req=grad_req, data_shapes..., label_shapes...)
dbg_str = mx.debug_str(train_execs[i])
info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i]))

@@ -463,6 +484,10 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra

# update parameters
for idx = 1:length(param_names)
if in(idx, freeze_idx)
continue # Skip parameter update entirely
end

# gradient synchronization
if !isa(kvstore, Void)
# push gradient, priority is negative index
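A toy illustration of the bookkeeping added to fit() above (plain Julia, no MXNet state involved; names are made up), showing how freeze_idx lines up with the update loop:

    param_names  = [:fc1_weight, :fc1_bias, :fc2_weight, :fc2_bias]
    freeze_names = [:fc1_weight, :fc1_bias]
    freeze_idx   = filter(i -> in(param_names[i], freeze_names), 1:length(param_names))
    # freeze_idx == [1, 2]: the update loop `continue`s on these indices, so the
    # frozen parameters keep their initial (e.g. pre-trained) values.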
25 changes: 24 additions & 1 deletion src/symbolic-node.jl
@@ -132,14 +132,37 @@ function get_attr(self :: SymbolicNode, key :: Symbol)
key_s = bytestring(string(key))
ref_out = Ref{Cstring}()
ref_success = Ref{Cint}(-1)
@mxcall(:MXSymbolGetAttr, (MX_handle, Cstring, Ref{Cstring}, Ref{Cint}), self, key_s, ref_out, ref_success)
@mxcall(:MXSymbolGetAttr, (MX_handle, Cstring, Ref{Cstring}, Ref{Cint}),
self, key_s, ref_out, ref_success)
if ref_success[] == 1
return Nullable{ByteString}(bytestring(ref_out[]))
else
return Nullable{ByteString}()
end
end

#=doc
.. function:: list_attr(self :: SymbolicNode)
Get all attributes from symbol.
:return: Dictionary of attributes.
=#
function list_attr(self :: SymbolicNode)
ref_sz = Ref{MX_uint}(0)
ref_strings = Ref{char_pp}(0)
@mxcall(:MXSymbolListAttr, (MX_handle, Ref{MX_uint}, Ref{char_pp}),
self, ref_sz, ref_strings)
narg = 2*ref_sz[]
strings = pointer_to_array(ref_strings[], narg)
out = Dict{Symbol, ByteString}()
for i in 1:2:narg
key = symbol(bytestring(strings[i]))
value = bytestring(strings[i+1])
out[key] = value
end
return out
end

#=doc
.. function:: set_attr(self:: SymbolicNode, key :: Symbol, value :: AbstractString)
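The new list_attr complements get_attr; a quick sketch of both (the attribute name is an illustrative assumption, and for composite symbols the returned keys may carry a node-name prefix):

    x = mx.Variable(:x, attrs = Dict(:lr_mult => "0.1"))
    mx.list_attr(x)                   # Dict(:lr_mult => "0.1")
    get(mx.get_attr(x, :lr_mult))     # "0.1"
    isnull(mx.get_attr(x, :wd_mult))  # true: attribute was never set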
1 change: 1 addition & 0 deletions test/unittest/symbolic-node.jl
@@ -100,6 +100,7 @@ function test_attrs()
@test isnull(mx.get_attr(conv, :b))
@test get(mx.get_attr(conv, :a)) == "a"
@test get(mx.get_attr(conv, :π)) == "π"
@test mx.list_attr(conv) == Dict(:a => "a", :π => "π")

@test_throws MethodError mx.Variable(:data3, attrs = Dict(:test => "1.0", :test2 => 1.0))
@test_throws MethodError mx.Convolution(data=data2, kernel = (1,1), num_filter = 1, attrs = Dict(:test => "1.0", :test2 => 1.0))
