From 463f202e6fa4ad6eaaa2f82e93ddd2d4ec782981 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 9 May 2016 11:52:24 +0900 Subject: [PATCH] add support for ListAttrShallow --- src/model.jl | 2 +- src/symbolic-node.jl | 24 +++++++++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index c025dc17091c..3984eb9d1389 100644 --- a/src/model.jl +++ b/src/model.jl @@ -371,7 +371,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra # get grad attribute to allow for freezing freeze_names = Symbol[] - for (attr, value) in list_attr(self.arch) + for (attr, value) in list_all_attr(self.arch) sattr = string(attr) if endswith(sattr, "grad") && value == "freeze" push!(freeze_names, symbol(sattr[1:end-5])) diff --git a/src/symbolic-node.jl b/src/symbolic-node.jl index 0bc3b593da8a..dcaae9bd5dc6 100644 --- a/src/symbolic-node.jl +++ b/src/symbolic-node.jl @@ -144,10 +144,32 @@ end #=doc .. function: list_attr(self :: SymbolicNode) - Get all attributes from symbol. + Get all attributes from a symbol. :return: Dictionary of attributes. =# function list_attr(self :: SymbolicNode) + ref_sz = Ref{MX_uint}(0) + ref_strings = Ref{char_pp}(0) + @mxcall(:MXSymbolListAttrShallow, (MX_handle, Ref{MX_uint}, Ref{char_pp}), + self, ref_sz, ref_strings) + narg = 2*ref_sz[] + strings = pointer_to_array(ref_strings[], narg) + out = Dict{Symbol, ByteString}() + for i in 1:2:narg + key = symbol(bytestring(strings[i])) + value = bytestring(strings[i+1]) + out[key] = value + end + return out +end + +#=doc +.. function: list_all_attr(self :: SymbolicNode) + + Get all attributes from the symbol graph. + :return: Dictionary of attributes. +=# +function list_all_attr(self :: SymbolicNode) ref_sz = Ref{MX_uint}(0) ref_strings = Ref{char_pp}(0) @mxcall(:MXSymbolListAttr, (MX_handle, Ref{MX_uint}, Ref{char_pp}),