Skip to content

Commit

Permalink
add runtime version of macro gendict so that @defvar(m,x[1:3]) and T=…
Browse files Browse the repository at this point in the history
…1:3;@defvar(m,x[T]) have the same semantics
  • Loading branch information
joehuchette committed Sep 27, 2015
1 parent 6b179d2 commit d032934
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 4 deletions.
84 changes: 82 additions & 2 deletions src/JuMPContainer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Base.isempty(d::JuMPContainer) = isempty(_innercontainer(d))
# 0:K -- range with compile-time starting index
# S -- general iterable set
export @gendict
macro gendict(instancename,T,idxpairs,idxsets...)
macro gendict(instancename,T,idxsets...)
N = length(idxsets)
allranges = all(s -> (isexpr(s,:(:)) && length(s.args) == 2), idxsets)
truearray = allranges && all(s -> s.args[1] == 1, idxsets)
Expand Down Expand Up @@ -165,12 +165,92 @@ macro gendict(instancename,T,idxpairs,idxsets...)
=#
else
# JuMPDict
escidxs = [esc(idxset) for idxset in idxsets]
return :(
$(esc(instancename)) = JuMPDict{$T,$N}()
$(esc(instancename)) = (if is_unit_ranges($(escidxs...))
runtime_gendict($T, $(escidxs...))
else
JuMPDict{$T,$N}()
end)
)
end
end

function runtime_gendict(T,idxsets...)
N = length(idxsets)
allranges = all(s -> (typeof(s) <: Range), idxsets)
truearray = allranges && all(s -> (first(s) == 1), idxsets)
if allranges
if truearray
return Array(T, [last(rng) for rng in idxsets]...)
else
typename = symbol(string("JuMPArray",gensym()))
dictnames = Array(Symbol,N)
# JuMPArray
offset = Array(Int,N)
for i in 1:N
offset[i] = 1 - first(idxsets[i])
end
typecode = quote
type $(typename){T} <: JuMPArray{T,$N}
innerArray::Array{T,$N}
meta::Dict{Symbol,Any}
end
end
constrlhs = :($(typename)(innerArray::Array))
constrrhs = :($(typename)(innerArray, Dict{Symbol,Any}()))
getidxlhs = :(Base.getindex(d::$(typename)))
setidxlhs = :(Base.setindex!(d::$(typename),val))
getidxrhs = :(Base.getindex(d.innerArray))
setidxrhs = :(Base.setindex!(d.innerArray,val))
maplhs = :(Base.map(f::Function,d::$(typename)))
maprhs = :($(typename)(map(f,d.innerArray),d.meta))
wraplhs = :(JuMPContainer_from(d::$(typename),inner)) # helper function that wraps array into JuMPArray of similar type
wraprhs = :($(typename)(inner))

nextidxlhs = :(_next_index(d::$(typename), k))
# build up exprs for _next_index
lidxsets = [ii => symbol(string("locidxset",ii)) for ii in 1:N]
nextidxrhs = quote
subidx = ind2sub(size(d), k)
$(Expr(:tuple, [:(subidx[$ii] - $(offset[ii])) for ii in 1:N]...))
end
for i in 1:N
varname = symbol(string("x",i))

push!(getidxlhs.args,:($varname))
push!(setidxlhs.args,:($varname))

push!(getidxrhs.args,:(isa($varname, Int) ? $varname+$(offset[i]) : $varname ))
push!(setidxrhs.args,:($varname+$(offset[i])))

end

badgetidxlhs = :(Base.getindex(d::$(typename),wrong...))
badgetidxrhs = :(data = printdata(d);
error("Wrong number of indices for ",data.name, ", expected ",length(data.indexsets)))

funcs = quote
$constrlhs = $constrrhs
$getidxlhs = $getidxrhs
$setidxlhs = $setidxrhs
$maplhs = $maprhs
$badgetidxlhs = $badgetidxrhs
$wraplhs = $wraprhs
$nextidxlhs = $nextidxrhs
end

eval(Expr(:toplevel, typecode))
eval(Expr(:toplevel, funcs))

return eval(:($(typename)(Array($T, [length(idxset) for idxset in $idxsets]...))))
end
else
error("Should not reach this point")
end
end

@generated is_unit_ranges(idxsets...) = :($(all(s -> s <: UnitRange{Int}, idxsets)))
pushmeta!(x::JuMPContainer, sym::Symbol, val) = (x.meta[sym] = val)
getmeta(x::JuMPContainer, sym::Symbol) = x.meta[sym]

Expand Down
4 changes: 2 additions & 2 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ function getloopedcode(c::Expr, code, condition, idxvars, idxsets, idxpairs, sym
N = length(idxsets)
mac = :($(esc(varname)) = JuMPDict{$(sym),$N}())
else
mac = Expr(:macrocall,symbol("@gendict"),esc(varname),sym,idxpairs,idxsets...)
mac = Expr(:macrocall,symbol("@gendict"),esc(varname),sym,idxsets...)
end
return quote
$mac
Expand Down Expand Up @@ -877,7 +877,7 @@ macro defConstrRef(var)
idxsets = var.args[2:end]
idxpairs = IndexPair[]

mac = Expr(:macrocall,symbol("@gendict"),varname,:ConstraintRef,idxpairs, idxsets...)
mac = Expr(:macrocall,symbol("@gendict"), varname, :ConstraintRef, idxsets...)
code = quote
$(esc(mac))
nothing
Expand Down
20 changes: 20 additions & 0 deletions test/perf/JuMPArray-iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,23 @@ bench(1)
bench(100)
bench(1000)
bench(2000)

function bench_runtime(n)
t1 = @elapsed begin
m = Model()
I, J = 1:n, 2:n
@defVar(m, x[I,J])
end
t2 = @elapsed begin
cntr = 0
for (ii,jj,v) in x
cntr += ii + jj + v.col
end
end
t1, t2
end

bench_runtime(1)
bench_runtime(100)
bench_runtime(1000)
bench_runtime(2000)
64 changes: 64 additions & 0 deletions test/perf/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,67 @@ for N in [20,50,100]
println(" N=$(N) min $(minimum(N2_times))")
end

function test_linear_runtime(N)
m = Model()
I,J = 1:10N, 1:5N
@defVar(m, x[I,J])
K = 1:N
@defVar(m, y[K,K,K])

for z in 1:10
@addConstraint(m,
9*y[1,1,1] - 5*y[N,N,N] -
2*sum{ z*x[j,i*N], j=((z-1)*N+1):z*N, i=3:4} +
sum{ i*(9*x[i,j] + 3*x[j,i]), i=N:2N, j=N:2N} +
x[1,1] + x[10N,5N] + x[2N,1] +
1*y[1,1,N] + 2*y[1,N,1] + 3*y[N,1,1] +
y[N,N,N] - 2*y[N,N,N] + 3*y[N,N,N]
<=
sum{sum{sum{N*i*j*k*y[i,j,k] + x[i,j],k=1:N; i!=j && j!=k},j=1:N},i=1:N} +
sum{sum{x[i,j], j=1:5N; j % i == 3}, i=1:10N; i <= N*z}
)
end
end

function test_quad_runtime(N)
m = Model()
I,J = 1:10N, 1:5N
@defVar(m, x[I,J])
K = 1:N
@defVar(m, y[K,K,K])

for z in 1:10
@addConstraint(m,
9*y[1,1,1] - 5*y[N,N,N] -
2*sum{ z*x[j,i*N], j=((z-1)*N+1):z*N, i=3:4} +
sum{ i*(9*x[i,j] + 3*x[j,i]), i=N:2N, j=N:2N} +
x[1,1] + x[10N,5N] * x[2N,1] +
1*y[1,1,N] * 2*y[1,N,1] + 3*y[N,1,1] +
y[N,N,N] - 2*y[N,N,N] * 3*y[N,N,N]
<=
sum{sum{sum{N*i*j*k*y[i,j,k] * x[i,j],k=1:N; i!=j && j!=k},j=1:N},i=1:N} +
sum{sum{x[i,j], j=1:5N; j % i == 3}, i=1:10N; i <= N*z}
)
end
end


# Warmup
println("Test 2 (runtime)")
test_linear_runtime(1)
test_quad_runtime(1)
for N in [20,50,100]
println(" Running N=$(N)...")
N1_times = {}
N2_times = {}
for iter in 1:10
tic()
test_linear_runtime(N)
push!(N1_times, toq())
tic()
test_quad_runtime(N)
push!(N2_times, toq())
end
println(" N=$(N) min $(minimum(N1_times))")
println(" N=$(N) min $(minimum(N2_times))")
end

0 comments on commit d032934

Please sign in to comment.