Skip to content

Commit

Permalink
Merge pull request apache#92 from vchuravy/vc/fixup
Browse files Browse the repository at this point in the history
add support for ListAttrShallow
  • Loading branch information
pluskid committed May 9, 2016
2 parents 2b0d825 + 463f202 commit e7cd135
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra

# get grad attribute to allow for freezing
freeze_names = Symbol[]
for (attr, value) in list_attr(self.arch)
for (attr, value) in list_all_attr(self.arch)
sattr = string(attr)
if endswith(sattr, "grad") && value == "freeze"
push!(freeze_names, symbol(sattr[1:end-5]))
Expand Down
24 changes: 23 additions & 1 deletion src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,32 @@ end
#=doc
.. function: list_attr(self :: SymbolicNode)
Get all attributes from symbol.
Get all attributes from a symbol.
:return: Dictionary of attributes.
=#
function list_attr(self :: SymbolicNode)
ref_sz = Ref{MX_uint}(0)
ref_strings = Ref{char_pp}(0)
@mxcall(:MXSymbolListAttrShallow, (MX_handle, Ref{MX_uint}, Ref{char_pp}),
self, ref_sz, ref_strings)
narg = 2*ref_sz[]
strings = pointer_to_array(ref_strings[], narg)
out = Dict{Symbol, ByteString}()
for i in 1:2:narg
key = symbol(bytestring(strings[i]))
value = bytestring(strings[i+1])
out[key] = value
end
return out
end

#=doc
.. function: list_all_attr(self :: SymbolicNode)
Get all attributes from the symbol graph.
:return: Dictionary of attributes.
=#
function list_all_attr(self :: SymbolicNode)
ref_sz = Ref{MX_uint}(0)
ref_strings = Ref{char_pp}(0)
@mxcall(:MXSymbolListAttr, (MX_handle, Ref{MX_uint}, Ref{char_pp}),
Expand Down

0 comments on commit e7cd135

Please sign in to comment.