Skip to content

Commit 3a79bf2

Browse files
committed
start thinking about pipelines and native modules
1 parent 9ea8073 commit 3a79bf2

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

src/module/Module.jl

+2
Original file line numberDiff line numberDiff line change
@@ -516,5 +516,7 @@ end
516516

517517
# include implementations
518518
include("symbol_module.jl")
519+
include("pipeline.jl")
520+
include("native_module.jl")
519521

520522
end

src/module/native_module.jl

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
NativeModule
3+
4+
Allows the implementation of a MXNet module in native Julia. NDArrays
5+
will be translated into native Julia arrays.
6+
"""
7+
type NativeModule{F<:Function,B<:Function} <: AbstractModule
8+
forward :: F
9+
backward :: B
10+
end

src/module/pipeline.jl

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
abstract PipelineModule <: AbstractModule
2+
3+
"""
4+
SimplePipelineModule
5+
6+
Allows the pipelining of several modules.
7+
8+
# Arguments:
9+
* `pipeline :: Vector{Module}`
10+
The elements that are called sequentially
11+
12+
# Functionality
13+
*
14+
"""
15+
type SimplePipelineModule <: PipelineModule
16+
pipeline :: Vector{Module}
17+
end
18+
19+
type ModuleDataProvider <: mx.AbstractDataProvider
20+
mod :: Module
21+
end
22+
23+
24+
function forward(self :: SimplePipelineModule)
25+
for mod in self.pipeline
26+
forward(mod)
27+
end
28+
end
29+
30+
function backward(self :: SimplePipelineModule)
31+
for i in length(self.pipeline):-1:1
32+
mod = self.pipeline[i]
33+
backward(mod)
34+
end
35+
end
36+
37+
function get_outputs(self :: SimplePipelineModule)
38+
return get_outputs(last(self.pipeline))
39+
end

test/unittest/symbol-module.jl

+3
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ function test_linear_regression(n_epoch::Int = 10)
133133
@test sum(abs(ha_pred-y_pred)) < 1e-1
134134
end
135135

136+
function test_simplepipeline()
137+
end
138+
136139
################################################################################
137140
# Run tests
138141
################################################################################

0 commit comments

Comments
 (0)