Skip to content

Commit

Permalink
Fractal example
Browse files Browse the repository at this point in the history
Environment loop fixes, plus allow for cancelling loop variants

Some adjustments to vector and math interfaces

Call(), CallWrap(), HandleDim(), and TwoArrays() now take the name instead of the function itself to provide better error messages

Added elements() method to array
  • Loading branch information
ggcrunchy committed Dec 17, 2015
1 parent 5e2c951 commit f346637
Show file tree
Hide file tree
Showing 21 changed files with 281 additions and 214 deletions.
5 changes: 2 additions & 3 deletions scripts/lib/funcs/image.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
--- Image-related functions.

-- Modules --
local af = require("arrayfire")
local array = require("lib.impl.array")

-- Imports --
Expand All @@ -14,7 +13,7 @@ local M = {}

--
local function HistEqual (in_arr, hist)
return CallWrap(af.af_hist_equal, in_arr:get(), hist:get())
return CallWrap("af_hist_equal", in_arr:get(), hist:get())
end

--
Expand All @@ -31,7 +30,7 @@ function M.Add (into)
minval, maxval = lib.min("f64", in_arr), lib.max("f64", in_arr)
end

return CallWrap(af.af_histogram, in_arr:get(), nbins, minval, maxval)
return CallWrap("af_histogram", in_arr:get(), nbins, minval, maxval)
end
} do
into[k] = v
Expand Down
14 changes: 7 additions & 7 deletions scripts/lib/funcs/io.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,44 @@ local M = {}
-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/imageio.cpp

local function LoadImage (filename, is_color)
return CallWrap(af.af_load_image, filename, is_color)
return CallWrap("af_load_image", filename, is_color)
end

--
local function SaveImage (filename, in_arr)
Call(af.af_save_image, filename, in_arr:get())
Call("af_save_image", filename, in_arr:get())
end

--
function M.Add (into)
for k, v in pairs{
deleteImageMem = function(ptr)
Call(af.af_delete_image_memory, ptr)
Call("af_delete_image_memory", ptr)
end,

--
loadImage = LoadImage, loadimage = LoadImage,

--
loadImageMem = function(ptr)
return CallWrap(af.af_load_image_memory, ptr)
return CallWrap("af_load_image_memory", ptr)
end,

loadImageNative = function(filename)
return CallWrap(af.af_load_image_native, filename)
return CallWrap("af_load_image_native", filename)
end,

--
saveImage = SaveImage, saveimage = SaveImage,

--
saveImageMem = function(in_arr, format)
return Call(af.af_save_image_memory, in_arr:get(), af[format or "AF_FIF_PNG"])
return Call("af_save_image_memory", in_arr:get(), af[format or "AF_FIF_PNG"])
end,

--
saveImageNative = function(filename, in_arr)
Call(af.af_save_image_native, filename, in_arr:get())
Call("af_save_image_native", filename, in_arr:get())
end
} do
into[k] = v
Expand Down
44 changes: 22 additions & 22 deletions scripts/lib/funcs/linear_algebra.lua
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ local M = {}
local function MatMul2 (a, b, opt_lhs, opt_rhs)
opt_lhs, opt_rhs = opt_lhs or "AF_MAT_NONE", opt_rhs or "AF_MAT_NONE"

return CallWrap(af.af_matmul, a:get(), b:get(), af[opt_lhs], af[opt_rhs])
return CallWrap("af_matmul", a:get(), b:get(), af[opt_lhs], af[opt_rhs])
end

--
Expand Down Expand Up @@ -68,7 +68,7 @@ function M.Add (into)
is_upper = true
end

local res, info = Call(af.af_cholesky, in_arr:get(), is_upper)
local res, info = Call("af_cholesky", in_arr:get(), is_upper)

out:set(res)

Expand All @@ -81,28 +81,28 @@ function M.Add (into)
is_upper = true
end

return Call(af.af_cholesky_inplace, in_arr:get(), is_upper)
return Call("af_cholesky_inplace", in_arr:get(), is_upper)
end,

--
det = function(rtype, arr)
return ToType(rtype, Call(af.af_det, arr:get()))
return ToType(rtype, Call("af_det", arr:get()))
end,

--
dot = function(a, b, opt_lhs, opt_rhs)
return CallWrap(af.af_dot, a:get(), b:get(), af[opt_lhs or "AF_MAT_NONE"], af[opt_rhs or "AF_MAT_NONE"])
return CallWrap("af_dot", a:get(), b:get(), af[opt_lhs or "AF_MAT_NONE"], af[opt_rhs or "AF_MAT_NONE"])
end,

--
inverse = function(arr, options)
return CallWrap(arr:get(), af[options or "AF_MAT_NONE"])
return CallWrap("af_inverse", arr:get(), af[options or "AF_MAT_NONE"])
end,

--
lu = function(a, b, c, d)
if IsArray(d) then -- a: lower, b: upper, c: pivot, d: in
local l, u, p = Call(af.af_lu, d:get())
local l, u, p = Call("af_lu", d:get())

a:set(l)
b:set(u)
Expand All @@ -112,8 +112,8 @@ function M.Add (into)
d = true
end

a:set(Call(af.af_copy_array, c:get()))
b:set(Call(af.af_lu_inplace, a:get(), d))
a:set(Call("af_copy_array", c:get()))
b:set(Call("af_lu_inplace", a:get(), d))
end
end,

Expand All @@ -123,7 +123,7 @@ function M.Add (into)
is_lapack_piv = true
end

pivot:set(Call(af.af_lu_inplace, in_arr:get(), is_lapack_piv))
pivot:set(Call("af_lu_inplace", in_arr:get(), is_lapack_piv))
end,

--
Expand Down Expand Up @@ -154,57 +154,57 @@ function M.Add (into)

--
norm = function(arr, norm_type, p, q)
return Call(af.af_norm, arr:get(), af[norm_type or "AF_NORM_EUCLID"], p or 1, q or 1)
return Call("af_norm", arr:get(), af[norm_type or "AF_NORM_EUCLID"], p or 1, q or 1)
end,

--
qr = function(a, b, c, d)
if IsArray(d) then -- a: q, b: r, c: tau, d: in
local q, r, t = Call(af.af_qr, d:get())
local q, r, t = Call("af_qr", d:get())

a:set(q)
b:set(r)
c:set(t)
else -- a: out, b: tau, c: in
a:set(Call(af.af_copy_array, c:get()))
b:set(Call(af.af_qr_inplace, a:get()))
a:set(Call("af_copy_array", c:get()))
b:set(Call("af_qr_inplace", a:get()))
end
end,

--
qrInPlace = function(tau, in_arr)
tau:set(Call(af.af_qr_inplace, in_arr:get()))
tau:set(Call("af_qr_inplace", in_arr:get()))
end,

--
rank = function(arr, tol)
return Call(af.af_rank, arr:get(), tol or 1e-5)
return Call("af_rank", arr:get(), tol or 1e-5)
end,

--
solve = function(a, b, options)
return CallWrap(af.af_solve, a:get(), b:get(), af[options or "AF_MAT_NONE"])
return CallWrap("af_solve", a:get(), b:get(), af[options or "AF_MAT_NONE"])
end,

--
solveLU = function(a, piv, b, options)
return CallWrap(af.af_solve_lu, a:get(), piv:get(), b:get(), af[options or "AF_MAT_NONE"])
return CallWrap("af_solve_lu", a:get(), piv:get(), b:get(), af[options or "AF_MAT_NONE"])
end,

--
svd = SVD(af.af_svd),
svd = SVD("af_svd"),

--
svdInPlace = SVD(af.af_svd_inplace),
svdInPlace = SVD("af_svd_inplace"),

--
transpose = function(arr, conjugate)
return CallWrap(af.af_transpose, arr:get(), conjugate)
return CallWrap("af_transpose", arr:get(), conjugate)
end,

--
transposeInPlace = function(arr, conjugate)
return Call(af.af_transpose_inplace, arr:get(), conjugate)
return Call("af_transpose_inplace", arr:get(), conjugate)
end
} do
into[k] = v
Expand Down
20 changes: 13 additions & 7 deletions scripts/lib/funcs/mathematics.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,33 @@ local array = require("lib.impl.array")

-- Imports --
local CallWrap = array.CallWrap
local IsArray = array.IsArray
local TwoArrays = array.TwoArrays

-- Exports --
local M = {}

-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/binary.cpp

--
local function Binary (func)
return function(a, b, batch)
return TwoArrays(func, a, b, batch)
local function Binary (name)
return function(a, b)
return TwoArrays(name, a, b--[[TODO: IsArray(a) and IsArray(b) and gfor_get]])
end
end

--
local function Unary (func)
local function Unary (name)
return function(in_arr)
return CallWrap(func, in_arr:get())
return CallWrap(name, in_arr:get())
end
end

local function LoadFuncs (into, funcs, op)
for _, v in ipairs(funcs) do
local func = af["af_" .. v]
local name = "af_" .. v

into[v] = func and op(func) -- ignore conditionally unavailable functions
into[v] = af[name] and op(name) -- ignore conditionally unavailable functions
end
end

Expand Down Expand Up @@ -103,6 +106,9 @@ function M.Add (into)
"root",
"sub"
}, Binary)

-- Use C++ name. (TODO: maxof, minof... re. vector)
into.complex, into.cplx2 = into.cplx2
end

-- Export the module.
Expand Down
8 changes: 4 additions & 4 deletions scripts/lib/funcs/signal_processing.lua
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function M.Add (into)
--
convolve = function(a, b, c, d)
if IsArray(c) then -- a: col_filter, b: row_filter, c: signal, d: mode
return CallWrap(af.af_convolve2_sep, a:get(), b:get(), c:get(), af[mode or "AF_CONV_DEFAULT"])
return CallWrap("af_convolve2_sep", a:get(), b:get(), c:get(), af[mode or "AF_CONV_DEFAULT"])
else -- a: signal, b: filter, c: mode, d: domain
local n, func = min(a:numdims(), b:numdims())

Expand All @@ -64,17 +64,17 @@ function M.Add (into)

--
fft = function(in_arr, dim0)
return CallWrap(af.af_fft, in_arr:get(), 1, dim0 or 0)
return CallWrap("af_fft", in_arr:get(), 1, dim0 or 0)
end,

--
fft2 = function(in_arr, dim0, dim1)
return CallWrap(af.af_fft2, in_arr:get(), 1, dim0 or 0, dim1 or 0)
return CallWrap("af_fft2", in_arr:get(), 1, dim0 or 0, dim1 or 0)
end,

--
fft3 = function(in_arr, dim0, dim1, dim2)
return CallWrap(af.af_fft3, in_arr:get(), 1, dim0 or 0, dim1 or 0, dim2 or 0)
return CallWrap("af_fft3", in_arr:get(), 1, dim0 or 0, dim1 or 0, dim2 or 0)
end
} do
into[k] = v
Expand Down
5 changes: 2 additions & 3 deletions scripts/lib/funcs/static.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
--- Static companions of array methods.

-- Modules --
local af = require("arrayfire")
local array = require("lib.impl.array")

-- Imports --
Expand Down Expand Up @@ -46,14 +45,14 @@ function M.Add (into)
getDims = function(arr, out)
out = out or {}

out[1], out[2], out[3], out[4] = Call(af.af_get_dims, arr:get())
out[1], out[2], out[3], out[4] = Call("af_get_dims", arr:get())

return out
end,

--
numDims = function(arr)
return Call(af.af_get_numdims, arr:get())
return Call("af_get_numdims", arr:get())
end,
} do
into[k] = v
Expand Down
9 changes: 4 additions & 5 deletions scripts/lib/funcs/statistics.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
local type = type

-- Modules --
local af = require("arrayfire")
local array = require("lib.impl.array")

-- Imports --
Expand All @@ -23,14 +22,14 @@ local M = {}
local function Mean (a, b, c)
if type(a) == "string" then -- a: type, b: in_arr, c: weights
if IsArray(c) then
return ToType(a, Call(af.af_mean_all_weighted, b:get(), c:get()))
return ToType(a, Call("af_mean_all_weighted", b:get(), c:get()))
else
return ToType(a, Call(af.af_mean_all, b:get()))
return ToType(a, Call("af_mean_all", b:get()))
end
elseif IsArray(b) then -- a: arr, b: weights, c: dim
return CallWrap(af.mean_weighted, a:get(), b, GetFNSD(c))
return CallWrap("mean_weighted", a:get(), b, GetFNSD(c))
else -- a: arr, b: dim
return HandleDim(af.af_mean, a, b)
return HandleDim("af_mean", a, b)
end
end

Expand Down
3 changes: 1 addition & 2 deletions scripts/lib/funcs/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
local print = print

-- Modules --
local af = require("arrayfire")
local array = require("lib.impl.array")

-- Imports --
Expand All @@ -18,7 +17,7 @@ function M.Add (into)
for k, v in pairs{
--
print = function(exp, arr, precision)
Call(af.af_print_array_gen, exp, arr:get(), precision or 4) -- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/util.cpp
Call("af_print_array_gen", exp, arr:get(), precision or 4) -- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/util.cpp
end,

--
Expand Down
Loading

0 comments on commit f346637

Please sign in to comment.