Skip to content

Commit

Permalink
Forgot to update convolve to use name-based call
Browse files Browse the repository at this point in the history
Fix for af_moddims binding, added to library too

Added flat()

SWE example

Additional timer support

progress() function for samples

Added wait_for_windows() predicate-builders
  • Loading branch information
ggcrunchy committed Dec 18, 2015
1 parent f346637 commit b54d860
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 40 deletions.
4 changes: 2 additions & 2 deletions scripts/lib/funcs/signal_processing.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ local M = {}

--
local function Convolve (dim)
local func = af["af_convolve" .. dim]
local name = "af_convolve" .. dim

return function(signal, filter, mode, domain)
return CallWrap(func, signal:get(), filter:get(), af[mode or "AF_CONV_DEFAULT"], af[domain or "AF_CONV_AUTO"])
return CallWrap(name, signal:get(), filter:get(), af[mode or "AF_CONV_DEFAULT"], af[domain or "AF_CONV_AUTO"])
end
end

Expand Down
6 changes: 5 additions & 1 deletion scripts/lib/impl/array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ local getmetatable = getmetatable
local rawequal = rawequal
local setmetatable = setmetatable
local tostring = tostring
local traceback = debug and debug.traceback
local type = type

-- Modules --
Expand Down Expand Up @@ -36,7 +37,10 @@ local Constants = setmetatable({}, { __mode = "k" })
--
local function CallFromName_Checked (name, ...)
if type(name) ~= "string" then
print(debug.traceback())
if traceback then
print(traceback())
end

error("Expected string name, got: " .. tostring(name))
end

Expand Down
20 changes: 20 additions & 0 deletions scripts/lib/methods/constructors.lua
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,29 @@ function M.Add (into)
return CallWrap(extract and "af_diag_extract" or "af_diag_create", in_arr:get(), num or 0)
end,

--
flat = function(in_arr)
return CallWrap("af_flat", in_arr:get())
end,

--
identity = DimsAndTypeFunc("af_identity"),

--
moddims = function(in_arr, a, b, c, d)
local ndims, dims

if type(a) == "table" then -- a: dims
ndims, dims = GetNDims(a), a
elseif type(b) == "table" then -- a: ndims, b: dims
ndims, dims = a, b
else -- a: d0, b: d1, c: d2, d: d3
ndims, dims = 4, PrepDims(a, b, c, d)
end

return CallWrap("af_moddims", in_arr:get(), ndims, dims)
end,

--
randn = DimsAndTypeFunc("af_randn"),

Expand Down
103 changes: 77 additions & 26 deletions scripts/lib/misc/program.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@
local clock = os.clock
local error = error
local getenv = os.getenv
local max = math.max
local pcall = pcall
local print = print
local select = select
local tonumber = tonumber

-- Modules --
local af = require("arrayfire")
local array = require("lib.impl.array")

-- Imports --
local GetLib = array.GetLib

-- Exports --
local M = {}

-- --
local IterPrev = 0
local TimePrev = 0
local MaxRate = 0

-- --
local T0

Expand Down Expand Up @@ -106,10 +115,11 @@ void read_idx(std::vector<dim_t> &dims, std::vector<ty> &data, const char *name)

local ok, err = pcall(function()
local device = argc > 1 and tonumber(argv[1]) or 0
local Lib = GetLib()

-- Select a device and display arrayfire info
af.af_set_device(device)
af.af_info()
Lib.setDevice(device)
Lib.info()

func(argc, argv)

Expand All @@ -133,34 +143,32 @@ void read_idx(std::vector<dim_t> &dims, std::vector<ty> &data, const char *name)
-- Are the windows redirecting IO?

--
progress = function()
--[[
bool progress(unsigned iter_curr, af::timer t, double time_total)
{
static unsigned iter_prev = 0;
static double time_prev = 0;
static double max_rate = 0;
progress = function(iter_curr, t, time_total)
local Lib = GetLib()

af::sync();
double time_curr = af::timer::stop(t);
Lib.sync()

if ((time_curr - time_prev) < 1) return true;
local time_curr = Lib.timer_stop(t)

double rate = (iter_curr - iter_prev) / (time_curr - time_prev);
printf(" iterations per second: %.0f (progress %.0f%%)\n",
rate, 100.0f * time_curr / time_total);
if time_curr - TimePrev < 1 then
return true
end

max_rate = std::max(max_rate, rate);
local rate = (iter_curr - IterPrev) / (time_curr - TimePrev)

iter_prev = iter_curr;
time_prev = time_curr;
Lib.printf(" iterations per second: %.0f (progress %.0f%%)", rate, 100.0 * time_curr / time_total)

MaxRate = max(MaxRate, rate)

if (time_curr < time_total) return true;
IterPrev = iter_curr
TimePrev = time_curr

printf(" ### %f iterations per second (max)\n", max_rate);
return false;
]]
if time_curr < time_total then
return true
end

Lib.printf(" ### %f iterations per second (max)", MaxRate)
return false
end,

--
Expand All @@ -178,10 +186,53 @@ bool progress(unsigned iter_curr, af::timer t, double time_total)
end,

--
timer_stop = function()
return clock() - T0
timer_stop = function(since)
return clock() - (since or T0)
end,

--
wait_for_windows = function(how, w1, w2, w3)
--
local function any_closed ()
if w3 then
return w1:close() or w2:close() or w3:close()
elseif w2 then
return w1:close() or w2:close()
else
return w1:close()
end
end

--
if how == "until" then
return any_closed
else
return function()
return not any_closed()
end
end
end,

--
wait_for_windows_close = function(how, w1, w2, w3)
local done = GetLib().wait_for_windows(how, w1, w2, w3)

return function()
if done() then
w1:destroy() -- TODO: Do this right... how?

if w2 then
w2:destroy()
end

if w3 then
w3:destroy()
end
end
end
--
end
-- ^^^ TODO: Better stuff? (resolution, instantiable)
-- ^^^ TODO: Better stuff? (resolution, instantiable?)
} do
into[k] = v
end
Expand Down
20 changes: 10 additions & 10 deletions scripts/tests/pde/swe.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local floor = math.floor
-- Modules --
local AF = require("lib.af_lib")

lib.main(function(argc, arv)
AF.main(function(argc, arv)
local win
local function normalize (a, max)
local mx = max * 0.5
Expand All @@ -14,32 +14,32 @@ lib.main(function(argc, arv)
local function swe (console)
local time_total = 20 -- run for N seconds
-- Grid length, number and spacing
local Lx, nx = 512, Lx + 1
local Ly, ny = 512, Ly + 1
local Lx, Ly = 512, 512
local nx, ny = Lx + 1, Ly + 1
local dx = Lx / (nx - 1)
local dy = Ly / (ny - 1)
local ZERO = AF.constant(0, nx, ny)
local um, vm = ZERO:copy(), ZERO:copy()
local io, jo, k = floor(Lx / 5.0), floor(Ly / 5.0), 20
-- local x = tile(moddims(AF.seq(nx),nx,1), 1,ny)
-- local y = tile(moddims(AF.seq(ny),1,ny), nx,1)
local x = AF.tile(AF.moddims(AF.array(AF.seq(nx)),nx,1), 1,ny)
local y = AF.tile(AF.moddims(AF.array(AF.seq(ny)),1,ny), nx,1)
-- Initial condition
local etam = 0.01 * AF.exp((-((x - io) * (x - io) + (y - jo) * (y - jo))) / (k * k))
local m_eta = AF.max("f32", etam)
local eta = etam:copy()
local dt = 0.5
-- conv kernels
local h_diff_kernel[] = {9.81 * (dt / dx), 0, -9.81 * (dt / dx)}
local h_lap_kernel[] = {0, 1, 0, 1, -4, 1, 0, 1, 0}
local h_diff_kernel = {9.81 * (dt / dx), 0, -9.81 * (dt / dx)}
local h_lap_kernel = {0, 1, 0, 1, -4, 1, 0, 1, 0}
local h_diff_kernel_arr = AF.array(3, h_diff_kernel)
local h_lap_kernel_arr = AF.array(3, 3, h_lap_kernel)
if not console then
win = AF.Window(512, 512,"Shallow Water Equations")
win:setColorMap("AF_COLORMAP_MOOD")
end
-- timer t = timer::start();
local t = AF.timer_start()
local iter = 0
AF.EnvLoopWhile_Args(function(env)
AF.EnvLoopWhile_Mode(function(env)
-- compute
local up = um + AF.convolve(eta, h_diff_kernel_arr)
local vp = um + AF.convolve(eta, h_diff_kernel_arr:T())
Expand All @@ -55,7 +55,7 @@ lib.main(function(argc, arv)
end
iter = iter + 1
end, function()
-- return AF.progress(iter, t, time_total)
return AF.progress(iter, t, time_total)
end, "normal_gc") -- evict old states every now and then
end

Expand Down
2 changes: 1 addition & 1 deletion src/Lua/arrayfire/methods/MoveReorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static const struct luaL_Reg move_reorder_methods[] = {

af_array * arr_ud = NewArray(L);// arr, ndims, dims, arr_ud

af_err err = af_moddims(arr_ud, GetArray(L, 2), dims.GetNDims(), dims.GetDims());
af_err err = af_moddims(arr_ud, GetArray(L, 1), dims.GetNDims(), dims.GetDims());

return PushErr(L, err); // arr, ndims, dims, err, arr_ud
}
Expand Down

0 comments on commit b54d860

Please sign in to comment.