Skip to content

Commit

Permalink
Merge pull request apache#98 from oist/vc/0.5
Browse files Browse the repository at this point in the history
Fix deprecation for v0.5 and adapt macro to expr changes
  • Loading branch information
pluskid committed May 27, 2016
2 parents 0ea4369 + afd9baa commit d1bf894
Show file tree
Hide file tree
Showing 18 changed files with 86 additions and 74 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ os:
- osx
julia:
- 0.4
#- nightly
- nightly

# dependent apt packages
addons:
Expand Down
3 changes: 2 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
julia 0.4
julia 0.4+
Compat
Formatting
BinDeps
JSON
4 changes: 4 additions & 0 deletions src/MXNet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ module MXNet
# functions with the same names as built-in utilities like "zeros", etc.
export mx
module mx

using Compat
import Compat.String

using Formatting

# Functions from base that we can safely extend and that are defined by libmxnet.
Expand Down
15 changes: 12 additions & 3 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function mx_get_last_error()
if msg == C_NULL
throw(MXError("Failed to get last error message"))
end
return bytestring(msg)
return @compat String(msg)
end

"Utility macro to call MXNet API functions"
Expand Down Expand Up @@ -189,10 +189,19 @@ function _defstruct_impl(is_immutable, name, fields)
if isa(name, Symbol)
name = esc(name)
super_name = :Any
elseif VERSION >= v"0.5-"
@assert(isa(name, Expr) && name.head == :(<:) && length(name.args) == 2 &&
isa(name.args[1], Symbol) && isa(name.args[2], Symbol),
"name must be of form 'Name <: SuperType'")

super_name = esc(name.args[2])
name = esc(name.args[1])
else
@assert(isa(name, Expr) && name.head == :comparison && length(name.args) == 3 && name.args[2] == :(<:),
@assert(isa(name, Expr) && name.head == :comparison &&
length(name.args) == 3 && name.args[2] == :(<:) &&
isa(name.args[1], Symbol) && isa(name.args[3], Symbol),
"name must be of form 'Name <: SuperType'")
@assert(isa(name.args[1], Symbol) && isa(name.args[3], Symbol))

super_name = esc(name.args[3])
name = esc(name.args[1])
end
Expand Down
2 changes: 1 addition & 1 deletion src/executor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,5 @@ Can be used to get an estimated about the memory cost.
function debug_str(self :: Executor)
s_ref = Ref{Cstring}()
@mxcall(:MXExecutorPrint, (MX_handle, Ptr{Cstring}), self.handle, s_ref)
bytestring(s_ref[])
@compat String(s_ref[])
end
12 changes: 5 additions & 7 deletions src/initializer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,6 @@ function _init_weight(self :: XavierInitializer, name :: Base.Symbol, array :: N
fan_in = prod(dims[2:end])
fan_out = dims[1]

if self.distribution == xv_uniform
func(σ, data) = rand!(-σ, σ, data)
elseif self.distribution == xv_normal
func(σ, data) = randn!(0.0, σ, data)
end

if self.regularization == xv_avg
factor = (fan_in + fan_out) / 2
elseif self.regularization == xv_in
Expand All @@ -154,5 +148,9 @@ function _init_weight(self :: XavierInitializer, name :: Base.Symbol, array :: N

σ = (self.magnitude / factor)

func(σ, array)
if self.distribution == xv_uniform
rand!(-σ, σ, array)
elseif self.distribution == xv_normal
randn!(0.0, σ, array)
end
end
6 changes: 3 additions & 3 deletions src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -584,15 +584,15 @@ function _define_data_iter_creator(hdr :: MX_handle; gen_docs::Bool=false)
(MX_handle, Ref{char_p}, Ref{char_p}, Ref{MX_uint}, Ref{char_pp}, Ref{char_pp}, Ref{char_pp}),
hdr, ref_name, ref_desc, ref_narg, ref_arg_names, ref_arg_types, ref_arg_descs)

iter_name = symbol(bytestring(ref_name[]))
iter_name = Symbol(String(ref_name[]))

if gen_docs
if endswith(string(iter_name), "Iter")
f_desc = "Can also be called with the alias ``$(string(iter_name)[1:end-4] * "Provider")``.\n"
else
f_desc = ""
end
f_desc *= bytestring(ref_desc[]) * "\n\n"
f_desc *= String(ref_desc[]) * "\n\n"
f_desc *= ":param Base.Symbol data_name: keyword argument, default ``:data``. The name of the data.\n"
f_desc *= ":param Base.Symbol label_name: keyword argument, default ``:softmax_label``. " *
"The name of the label. Could be ``nothing`` if no label is presented in this dataset.\n\n"
Expand All @@ -617,7 +617,7 @@ function _define_data_iter_creator(hdr :: MX_handle; gen_docs::Bool=false)

# add an alias XXXProvider => XXXIter
if endswith(string(iter_name), "Iter")
alias_name = symbol(string(iter_name)[1:end-4] * "Provider")
alias_name = Symbol(string(iter_name)[1:end-4] * "Provider")
eval(:($alias_name = $iter_name))
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/kvstore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ end
function get_type(self :: KVStore)
type_ref = Ref{char_p}(0)
@mxcall(:MXKVStoreGetType, (MX_handle, Ref{char_p}), self, type_ref)
return symbol(bytestring(type_ref[]))
return Symbol(@compat String(type_ref[]))
end

function get_num_workers(self :: KVStore)
Expand Down
2 changes: 1 addition & 1 deletion src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ type MultiACE <: AbstractEvalMetric
end

function get(metric :: MultiACE)
aces = [(symbol("ACE_$(i-0)"), - metric.aces[i] / metric.counts[i]) for i in 1:length(metric.aces)]
aces = [(Symbol("ACE_$(i-0)"), - metric.aces[i] / metric.counts[i]) for i in 1:length(metric.aces)]
push!(aces, (:ACE, - Base.sum(metric.aces) / Base.sum(metric.counts)))
return aces
end
Expand Down
8 changes: 4 additions & 4 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ function fit(self :: FeedForward, optimizer :: AbstractOptimizer, data :: Abstra
for (attr, value) in list_all_attr(self.arch)
sattr = string(attr)
if endswith(sattr, "grad") && value == "freeze"
push!(freeze_names, symbol(sattr[1:end-5]))
push!(freeze_names, Symbol(sattr[1:end-5]))
end
end
# Needs to correspond to the correct id in the update loop layer idx=1:length(param_names).
Expand Down Expand Up @@ -582,8 +582,8 @@ end
function save_checkpoint(sym :: SymbolicNode, arg_params :: Dict{Base.Symbol, NDArray},
aux_params :: Dict{Base.Symbol, NDArray}, prefix :: AbstractString, epoch :: Int)
save("$prefix-symbol.json", sym)
save_dict = merge(Dict([symbol("arg:$k") => v for (k,v) in arg_params]),
Dict([symbol("aux:$k") => v for (k,v) in aux_params]))
save_dict = merge(Dict([Symbol("arg:$k") => v for (k,v) in arg_params]),
Dict([Symbol("aux:$k") => v for (k,v) in aux_params]))
save_filename = format("{1}-{2:04d}.params", prefix, epoch)
save(save_filename, save_dict)
info("Saved checkpoint to '$save_filename'")
Expand All @@ -596,7 +596,7 @@ function load_checkpoint(prefix :: AbstractString, epoch :: Int)
aux_params = Dict{Base.Symbol, NDArray}()
for (k,v) in saved_dict
tp, name = split(string(k), ':')
name = symbol(name)
name = Symbol(name)
if tp == "arg"
arg_params[name] = v
else
Expand Down
10 changes: 5 additions & 5 deletions src/name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import Base: get!
# is automatically generated based on the hint string.
function _default_get_name!(counter :: NameCounter, name :: NameType, hint :: NameType)
if isa(name, Base.Symbol) || !isempty(name)
return symbol(name)
return Symbol(name)
end

hint = symbol(hint)
hint = Symbol(hint)
if !haskey(counter, hint)
counter[hint] = 0
end
name = symbol("$hint$(counter[hint])")
name = Symbol("$hint$(counter[hint])")
counter[hint] += 1
return name
end
Expand All @@ -34,11 +34,11 @@ type PrefixNameManager <: AbstractNameManager
prefix :: Base.Symbol
counter :: NameCounter
end
PrefixNameManager(prefix :: NameType) = PrefixNameManager(symbol(prefix), NameCounter())
PrefixNameManager(prefix :: NameType) = PrefixNameManager(Symbol(prefix), NameCounter())

function get!(manager :: PrefixNameManager, name :: NameType, hint :: NameType)
name = _default_get_name!(manager.counter, name, hint)
return symbol("$(manager.prefix)$name")
return Symbol("$(manager.prefix)$name")
end

DEFAULT_NAME_MANAGER = BasicNameManager()
26 changes: 13 additions & 13 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ function load(filename::AbstractString, ::Type{NDArray})
return [NDArray(MX_NDArrayHandle(hdr)) for hdr in pointer_to_array(out_hdrs[], out_size)]
else
@assert out_size == out_name_size
return Dict([(symbol(bytestring(k)), NDArray(MX_NDArrayHandle(hdr))) for (k,hdr) in
return Dict([(Symbol(@compat String(k)), NDArray(MX_NDArrayHandle(hdr))) for (k,hdr) in
zip(pointer_to_array(out_names[], out_size), pointer_to_array(out_hdrs[], out_size))])
end
end
Expand Down Expand Up @@ -903,11 +903,11 @@ function _import_ndarray_functions(;gen_docs=false)
func_handle, ref_name, ref_desc, ref_narg, ref_arg_names,
ref_arg_types, ref_arg_descs, ref_ret_type)

func_name = symbol(bytestring(ref_name[]))
func_name = Symbol(@compat String(ref_name[]))

if gen_docs
# generate document only
f_desc = bytestring(ref_desc[]) * "\n\n"
f_desc = @compat String(ref_desc[]) * "\n\n"
f_desc *= _format_docstring(Int(ref_narg[]), ref_arg_names, ref_arg_types, ref_arg_descs)
docs[func_name] = f_desc
else
Expand All @@ -932,18 +932,18 @@ function _import_ndarray_functions(;gen_docs=false)

# general ndarray function
if arg_before_scalar
args = vcat([Expr(:(::), symbol("in$i"), NDArray) for i=1:n_used_vars],
[Expr(:(::), symbol("sca$i"), Real) for i=1:n_scalars],
[Expr(:(::), symbol("out$i"), NDArray) for i=1:n_mutate_vars])
args = vcat([Expr(:(::), Symbol("in$i"), NDArray) for i=1:n_used_vars],
[Expr(:(::), Symbol("sca$i"), Real) for i=1:n_scalars],
[Expr(:(::), Symbol("out$i"), NDArray) for i=1:n_mutate_vars])
else
args = vcat([Expr(:(::), symbol("sca$i"), Real) for i=1:n_scalars],
[Expr(:(::), symbol("in$i"), NDArray) for i=1:n_used_vars],
[Expr(:(::), symbol("out$i"), NDArray) for i=1:n_mutate_vars])
args = vcat([Expr(:(::), Symbol("sca$i"), Real) for i=1:n_scalars],
[Expr(:(::), Symbol("in$i"), NDArray) for i=1:n_used_vars],
[Expr(:(::), Symbol("out$i"), NDArray) for i=1:n_mutate_vars])
end

_use_vars = Expr(:ref, :MX_handle, [symbol("in$i") for i=1:n_used_vars]...)
_scalars = Expr(:ref, :MX_float, [symbol("sca$i") for i=1:n_scalars]...)
_mut_vars = Expr(:ref, :MX_handle, [symbol("out$i") for i=1:n_mutate_vars]...)
_use_vars = Expr(:ref, :MX_handle, [Symbol("in$i") for i=1:n_used_vars]...)
_scalars = Expr(:ref, :MX_float, [Symbol("sca$i") for i=1:n_scalars]...)
_mut_vars = Expr(:ref, :MX_handle, [Symbol("out$i") for i=1:n_mutate_vars]...)

# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
Expand All @@ -955,7 +955,7 @@ function _import_ndarray_functions(;gen_docs=false)
if n_mutate_vars == 1
stmt_ret = :(return out1)
else
stmt_ret = Expr(:return, Expr(:tuple, [symbol("out$i") for i=1:n_mutate_vars]...))
stmt_ret = Expr(:return, Expr(:tuple, [Symbol("out$i") for i=1:n_mutate_vars]...))
end

func_body = Expr(:block, stmt_call, stmt_ret)
Expand Down
4 changes: 2 additions & 2 deletions src/nn-factory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ function MLP(input, spec; hidden_activation::Base.Symbol=:relu, prefix=gensym())
n_unit = s
act_type = hidden_activation
end
input = FullyConnected(input, name=symbol(prefix, "fc$i"), num_hidden=n_unit)
input = FullyConnected(input, name=Symbol(prefix, "fc$i"), num_hidden=n_unit)
if i < n_layer || isa(s, Tuple)
# will not add activation unless the user explicitly specified
input = Activation(input, name=symbol(prefix, "$act_type$i"), act_type=act_type)
input = Activation(input, name=Symbol(prefix, "$act_type$i"), act_type=act_type)
end
end

Expand Down
44 changes: 22 additions & 22 deletions src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ macro _list_symbol_info(self, func_name)
$self, ref_sz, ref_names)
narg = ref_sz[]
names = pointer_to_array(ref_names[], narg)
names = [symbol(bytestring(x)) for x in names]
names = [Symbol(@compat String(x)) for x in names]
return names
end
end
Expand Down Expand Up @@ -129,15 +129,15 @@ end
:return: The value belonging to key as a :class:`Nullable`.
=#
function get_attr(self :: SymbolicNode, key :: Symbol)
key_s = bytestring(string(key))
key_s = @compat String(string(key))
ref_out = Ref{Cstring}()
ref_success = Ref{Cint}(-1)
@mxcall(:MXSymbolGetAttr, (MX_handle, Cstring, Ref{Cstring}, Ref{Cint}),
self, key_s, ref_out, ref_success)
if ref_success[] == 1
return Nullable{ByteString}(bytestring(ref_out[]))
return Nullable{String}(@compat String(ref_out[]))
else
return Nullable{ByteString}()
return Nullable{String}()
end
end

Expand All @@ -154,10 +154,10 @@ function list_attr(self :: SymbolicNode)
self, ref_sz, ref_strings)
narg = 2*ref_sz[]
strings = pointer_to_array(ref_strings[], narg)
out = Dict{Symbol, ByteString}()
out = Dict{Symbol, String}()
for i in 1:2:narg
key = symbol(bytestring(strings[i]))
value = bytestring(strings[i+1])
key = Symbol(@compat String(strings[i]))
value = @compat String(strings[i+1])
out[key] = value
end
return out
Expand All @@ -176,10 +176,10 @@ function list_all_attr(self :: SymbolicNode)
self, ref_sz, ref_strings)
narg = 2*ref_sz[]
strings = pointer_to_array(ref_strings[], narg)
out = Dict{Symbol, ByteString}()
out = Dict{Symbol, String}()
for i in 1:2:narg
key = symbol(bytestring(strings[i]))
value = bytestring(strings[i+1])
key = Symbol(@compat String(strings[i]))
value = @compat String(strings[i+1])
out[key] = value
end
return out
Expand All @@ -198,8 +198,8 @@ end
cause unexpected behavior and inconsistency.
=#
function set_attr(self :: SymbolicNode, key :: Symbol, value :: AbstractString)
key_s = bytestring(string(key))
value_s = bytestring(value)
key_s = @compat String(string(key))
value_s = @compat String(value)

@mxcall(:MXSymbolSetAttr, (MX_handle, Cstring, Cstring), self, key_s, value_s)
end
Expand Down Expand Up @@ -325,7 +325,7 @@ end
indicating the index, as in the list of :func:`list_outputs`.
=#
function Base.getindex(self :: SymbolicNode, idx :: Union{Base.Symbol, AbstractString})
idx = symbol(idx)
idx = Symbol(idx)
i_idx = find(idx .== list_outputs(self))
@assert(length(i_idx) > 0, "Cannot find output with name '$idx'")
@assert(length(i_idx) < 2, "Found duplicated output with name '$idx'")
Expand Down Expand Up @@ -474,7 +474,7 @@ end
function to_json(self :: SymbolicNode)
ref_json = Ref{char_p}(0)
@mxcall(:MXSymbolSaveToJSON, (MX_handle, Ref{char_p}), self, ref_json)
return bytestring(ref_json[])
return @compat String(ref_json[])
end

#=doc
Expand Down Expand Up @@ -533,20 +533,20 @@ function _define_atomic_symbol_creator(hdr :: MX_handle; gen_docs=false)
hdr, ref_name, ref_desc, ref_nargs, ref_arg_names, ref_arg_types, ref_arg_descs,
ref_kv_nargs, ref_ret_type)

func_name_s= bytestring(ref_name[])
func_name = symbol(func_name_s)
kv_nargs_s = bytestring(ref_kv_nargs[])
kv_nargs = symbol(kv_nargs_s)
func_name_s= @compat String(ref_name[])
func_name = Symbol(func_name_s)
kv_nargs_s = @compat String(ref_kv_nargs[])
kv_nargs = Symbol(kv_nargs_s)

if gen_docs
f_desc = bytestring(ref_desc[]) * "\n\n"
f_desc = @compat String(ref_desc[]) * "\n\n"
if !isempty(kv_nargs_s)
f_desc *= "This function support variable length positional :class:`SymbolicNode` inputs.\n\n"
end
f_desc *= _format_docstring(Int(ref_nargs[]), ref_arg_names, ref_arg_types, ref_arg_descs)
f_desc *= ":param Symbol name: The name of the :class:`SymbolicNode`. (e.g. `:my_symbol`), optional.\n"
f_desc *= ":param Dict{Symbol, AbstractString} attrs: The attributes associated with this :class:`SymbolicNode`.\n\n"
f_desc *= ":return: $(_format_typestring(bytestring(ref_ret_type[]))).\n\n"
f_desc *= ":return: $(_format_typestring(@compat String(ref_ret_type[]))).\n\n"
return (func_name, f_desc)
end

Expand All @@ -565,7 +565,7 @@ function _define_atomic_symbol_creator(hdr :: MX_handle; gen_docs=false)
symbol_kws = Dict{Symbol, SymbolicNode}()
attrs = Dict{Symbol, AbstractString}()

$(if kv_nargs != symbol("")
$(if kv_nargs != Symbol("")
quote
if !in($kv_nargs_s, param_keys)
push!(param_keys, $kv_nargs_s)
Expand Down Expand Up @@ -593,7 +593,7 @@ function _define_atomic_symbol_creator(hdr :: MX_handle; gen_docs=false)
if length(args) != 0 && length(symbol_kws) != 0
@assert(false, $func_name_s * " only accepts Symbols either as positional or keyword arguments, not both.")
end
$(if kv_nargs != symbol("")
$(if kv_nargs != Symbol("")
quote
if length(symbol_kws) > 0
@assert(false, $func_name_s * " takes variable number of SymbolicNode arguments, " *
Expand Down
Loading

0 comments on commit d1bf894

Please sign in to comment.