Base structure for Native Operators #16

Closed · wants to merge 32 commits

Commits (32)

1bacb46  Base structure for Native Operators  (vchuravy, Nov 5, 2015)
56525ec  make NativeOpInfo work with precompilation  (vchuravy, Nov 5, 2015)
fd44835  Custom reimplementation of Base.SingleAsyncWork  (vchuravy, Nov 6, 2015)
bcf06d1  add explanation for _Async  (vchuravy, Nov 6, 2015)
a2133af  threadsafe handling for forward and backward  (vchuravy, Dec 12, 2015)
0ca00af  Define public facing interface  (vchuravy, Dec 12, 2015)
85845f3  conceptualise the interface for julia native ops  (vchuravy, Dec 12, 2015)
bb222b2  use tasks to handle forward and backward  (vchuravy, Dec 13, 2015)
19063ad  cleanup and comments  (vchuravy, Dec 13, 2015)
5725066  adapt to NDArrayOp  (vchuravy, Jan 7, 2016)
510de9e  handle forward and backward entry  (vchuravy, Jan 7, 2016)
fea55b1  handle declare backward dependency  (vchuravy, Jan 7, 2016)
f059ff9  weed out surface bugs  (vchuravy, Jan 7, 2016)
340a3f9  properly store outputs and arguments  (vchuravy, Jan 7, 2016)
a20241c  initial test example for softmax  (vchuravy, Jan 8, 2016)
91010a2  terminate list of strings with empty string to prevent segfaults  (vchuravy, Jan 8, 2016)
d26178a  Make initialization work  (vchuravy, Jan 8, 2016)
f2612ce  use unsafe_load instead of creating an array for forward/backward  (vchuravy, Jan 8, 2016)
e2f26c3  reverse shapes  (vchuravy, Jan 11, 2016)
8ab553c  throw method error instead of predefining functions  (vchuravy, Jan 11, 2016)
ca93d8d  use method signatures throughout  (vchuravy, Jan 11, 2016)
1f394cb  use reverse!  (vchuravy, Jan 11, 2016)
deac772  fix reversing shapes  (vchuravy, Jan 11, 2016)
09b9d5e  use convert instead of typeassert on return values  (vchuravy, Jan 11, 2016)
d872f02  fixup: import ..mx  (vchuravy, Jan 11, 2016)
e4906bb  define call for native operators  (vchuravy, Jan 11, 2016)
0aad560  introduce WeakKeyDict for storage management and use types instead of…  (vchuravy, Feb 16, 2016)
61fd0c8  ensure that received tags are correct  (vchuravy, Feb 18, 2016)
651a9e9  fixup backward  (vchuravy, Feb 18, 2016)
68ccb4b  use RawMutex.jl for backwards and forwards  (vchuravy, Aug 12, 2016)
209c478  Move mutex to later point so that we are not racy for the barrier  (vchuravy, Aug 12, 2016)
d3a2496  get closer to correct interaction between c++ threads and julia  (vchuravy, Aug 12, 2016)

63 changes: 63 additions & 0 deletions examples/julia-softmax.jl
@@ -0,0 +1,63 @@
using MXNet

import MXNet.mx.Native: Operator, list_arguments, list_outputs, infer_shape,
                        forward, backward, need_top_grad, create_op
type JuliaSoftmax <: Operator end

(op::JuliaSoftmax)(;kwargs...) = create_op(op;kwargs...)

list_arguments(:: JuliaSoftmax) = ["data", "label"]
list_outputs(:: JuliaSoftmax) = ["output"]
need_top_grad(:: JuliaSoftmax) = false

function infer_shape(::JuliaSoftmax, in_shapes :: Vector{Vector{UInt32}})
  data_shape = in_shapes[1]
  label_shape = [last(data_shape)]
  output_shape = data_shape
  return (data_shape, label_shape), (output_shape, )
end

function forward(::JuliaSoftmax, in_data :: Vector{mx.NDArray}, out_data :: Vector{mx.NDArray})
  info("Entering forward")
  x = in_data[1]
  y = out_data[1]

  @mx.nd_as_jl ro=x rw=y begin
    y[:] = exp(x - maximum(x, 1))
    y /= sum(y, 1)
  end
  info("Leaving forward")
end

#TODO: Correct gradient
function backward(::JuliaSoftmax, out_grad :: Vector{mx.NDArray}, in_data :: Vector{mx.NDArray}, out_data :: Vector{mx.NDArray}, in_grad :: Vector{mx.NDArray})
  info("Entering backward")
  label = in_data[2]
  y = out_data[1]
  dx = in_grad[1]

  @mx.nd_as_jl ro=(label, y) rw=dx begin
    dx[:] = y
  end
  info("Leaving backward")
end

# define mlp
data = mx.Variable("data")
fc1 = mx.FullyConnected(data = data, name="fc1", num_hidden=128)
act1 = mx.Activation(data = fc1, name="relu1", act_type="relu")
fc2 = mx.FullyConnected(data = act1, name="fc2", num_hidden=64)
act2 = mx.Activation(data = fc2, name="relu2", act_type="relu")
fc3 = mx.FullyConnected(data = act2, name="fc3", num_hidden=10)

# Setup Native operator
mysoftmax = JuliaSoftmax()
mlp = mysoftmax(name = "softmax", data=fc3)

model = mx.FeedForward(mlp, context = mx.cpu())
optimizer = mx.SGD(lr = 0.1, momentum = 0.9, weight_decay = 0.00001)

include("mnist/mnist-data.jl")
train_provider, eval_provider = get_mnist_providers(100)
mx.fit(model, optimizer, train_provider, eval_data=eval_provider, n_epoch=20)
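
The backward pass above is still marked "#TODO: Correct gradient" and simply copies the output probabilities into dx. For reference, here is a minimal sketch, not part of this PR, of what the softmax-with-cross-entropy gradient dL/dx = y - onehot(label) could look like under the same conventions, assuming (as in forward) that dimension 1 is the class axis and that label carries 0-based class indices:

# Hypothetical corrected backward; illustrative sketch only, not code from the PR.
function backward(::JuliaSoftmax, out_grad :: Vector{mx.NDArray}, in_data :: Vector{mx.NDArray}, out_data :: Vector{mx.NDArray}, in_grad :: Vector{mx.NDArray})
  label = in_data[2]
  y     = out_data[1]
  dx    = in_grad[1]

  @mx.nd_as_jl ro=(label, y) rw=dx begin
    dx[:] = y                        # start from the softmax probabilities
    for j in 1:size(dx, 2)           # one column per sample
      k = round(Int, label[j]) + 1   # shift the 0-based label to a 1-based index (assumption)
      dx[k, j] -= 1                  # subtract the one-hot target
    end
  end
end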

1 change: 1 addition & 0 deletions src/MXNet.jl
@@ -37,6 +37,7 @@ include("kvstore.jl")
include("callback.jl")
include("model.jl")

include("nativeops.jl")
include("visualize.jl")

include("nn-factory.jl")
15 changes: 15 additions & 0 deletions src/base.jl
@@ -60,6 +60,21 @@ macro mxcall(fv, argtypes, args...)
end
end

"Utility macro to call MXNet API functions from a threadpool"
macro mxthreadcall(fv, argtypes, args...)
  f = eval(fv)
  args = map(esc, args)
  quote
    _mxret = @threadcall( ($(Meta.quot(f)), $MXNET_LIB),
                          Cint, $argtypes, $(args...) )
    if _mxret != 0
      err_msg = mx_get_last_error()
      throw(MXError(err_msg))
    end
  end
end


################################################################################
# Handle types
################################################################################
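
@mxthreadcall mirrors the existing @mxcall, but issues the library call through Base's @threadcall, which runs the underlying ccall on a worker thread and suspends only the calling task instead of blocking the whole Julia scheduler. That matters for this PR because MXExecutorForward/MXExecutorBackward eventually hand work back to the Julia-defined operator, and that work is serviced by Julia tasks that can only run while the executor call yields. The following stand-alone sketch illustrates the difference, using POSIX usleep as a stand-in for a blocking MXNet call; it is not code from this PR and assumes a POSIX libc and a single Julia thread:

# A plain ccall would freeze the scheduler for the whole call; @threadcall runs the
# C function on a worker thread and suspends only this task, so `ticker` keeps printing.
ticker = @async for i in 1:5
  println("background task is alive ($i)")
  sleep(0.05)
end

@threadcall(:usleep, Cint, (Cuint,), 300_000)   # blocking C call, moved off the Julia scheduler
wait(ticker)
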
8 changes: 6 additions & 2 deletions src/executor.jl
@@ -164,7 +164,9 @@ function forward(self :: Executor; is_train::Bool=false, kwargs...)
     copy!(self.arg_dict[k], v)
   end
 
-  @mxcall(:MXExecutorForward, (MX_handle, Cint), self, is_train)
+  println("Forward")
+  @mxthreadcall(:MXExecutorForward, (MX_handle, Cint), self, is_train)
+  println("Forward")
 end

function backward(self :: Executor)
@@ -175,7 +177,9 @@ function backward(self :: Executor, out_grad :: NDArray)
 end
 function backward(self :: Executor, out_grads :: Vector{NDArray})
   out_grads = MX_handle[out_grads...]
-  @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), self, length(out_grads), out_grads)
+  println("Backward")
+  @mxthreadcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}), self, length(out_grads), out_grads)
+  println("Backward")
 end


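Switching forward and backward from @mxcall to @mxthreadcall is what keeps the Julia side responsive while the engine runs: judging by the commit history ("use tasks to handle forward and backward", "get closer to correct interaction between c++ threads and julia"), the native operator is invoked from an engine thread and its Julia work is picked up by a task, which can only be scheduled if the blocking executor call yields. A rough, self-contained toy of that producer/consumer shape follows; the names requests and servicer are invented for illustration and do not appear in the PR:

# Toy model of the control flow, not the PR's implementation.
requests = Channel{Symbol}(32)

servicer = @async for req in requests      # stands in for the task servicing native-op callbacks
  println("running the Julia $(req) callback for the engine")
end

put!(requests, :forward)                   # pretend an engine thread enqueued work
put!(requests, :backward)
close(requests)

wait(servicer)                             # both requests get serviced because nothing
                                           # blocked the Julia scheduler in the meantime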