From b54d8604977eee4e1e4ec9584aa6c998b0aec810 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Thu, 17 Dec 2015 19:08:09 -0600 Subject: [PATCH] Forgot to update convolve to use name-based call Fix for af_moddims binding, added to library too Added flat() SWE example Additional timer support progress() function for samples Added wait_for_windows() predicate-builders --- scripts/lib/funcs/signal_processing.lua | 4 +- scripts/lib/impl/array.lua | 6 +- scripts/lib/methods/constructors.lua | 20 +++++ scripts/lib/misc/program.lua | 103 ++++++++++++++++------ scripts/tests/pde/swe.lua | 20 ++--- src/Lua/arrayfire/methods/MoveReorder.cpp | 2 +- 6 files changed, 115 insertions(+), 40 deletions(-) diff --git a/scripts/lib/funcs/signal_processing.lua b/scripts/lib/funcs/signal_processing.lua index d167f73..d3d9391 100644 --- a/scripts/lib/funcs/signal_processing.lua +++ b/scripts/lib/funcs/signal_processing.lua @@ -24,10 +24,10 @@ local M = {} -- local function Convolve (dim) - local func = af["af_convolve" .. dim] + local name = "af_convolve" .. dim return function(signal, filter, mode, domain) - return CallWrap(func, signal:get(), filter:get(), af[mode or "AF_CONV_DEFAULT"], af[domain or "AF_CONV_AUTO"]) + return CallWrap(name, signal:get(), filter:get(), af[mode or "AF_CONV_DEFAULT"], af[domain or "AF_CONV_AUTO"]) end end diff --git a/scripts/lib/impl/array.lua b/scripts/lib/impl/array.lua index c59974e..6f1b17a 100644 --- a/scripts/lib/impl/array.lua +++ b/scripts/lib/impl/array.lua @@ -6,6 +6,7 @@ local getmetatable = getmetatable local rawequal = rawequal local setmetatable = setmetatable local tostring = tostring +local traceback = debug and debug.traceback local type = type -- Modules -- @@ -36,7 +37,10 @@ local Constants = setmetatable({}, { __mode = "k" }) -- local function CallFromName_Checked (name, ...) if type(name) ~= "string" then - print(debug.traceback()) + if traceback then + print(traceback()) + end + error("Expected string name, got: " .. tostring(name)) end diff --git a/scripts/lib/methods/constructors.lua b/scripts/lib/methods/constructors.lua index d3fc6a8..fd4b16f 100644 --- a/scripts/lib/methods/constructors.lua +++ b/scripts/lib/methods/constructors.lua @@ -189,9 +189,29 @@ function M.Add (into) return CallWrap(extract and "af_diag_extract" or "af_diag_create", in_arr:get(), num or 0) end, + -- + flat = function(in_arr) + return CallWrap("af_flat", in_arr:get()) + end, + -- identity = DimsAndTypeFunc("af_identity"), + -- + moddims = function(in_arr, a, b, c, d) + local ndims, dims + + if type(a) == "table" then -- a: dims + ndims, dims = GetNDims(a), a + elseif type(b) == "table" then -- a: ndims, b: dims + ndims, dims = a, b + else -- a: d0, b: d1, c: d2, d: d3 + ndims, dims = 4, PrepDims(a, b, c, d) + end + + return CallWrap("af_moddims", in_arr:get(), ndims, dims) + end, + -- randn = DimsAndTypeFunc("af_randn"), diff --git a/scripts/lib/misc/program.lua b/scripts/lib/misc/program.lua index 8fdc608..c841f5e 100644 --- a/scripts/lib/misc/program.lua +++ b/scripts/lib/misc/program.lua @@ -4,17 +4,26 @@ local clock = os.clock local error = error local getenv = os.getenv +local max = math.max local pcall = pcall local print = print local select = select local tonumber = tonumber -- Modules -- -local af = require("arrayfire") +local array = require("lib.impl.array") + +-- Imports -- +local GetLib = array.GetLib -- Exports -- local M = {} +-- -- +local IterPrev = 0 +local TimePrev = 0 +local MaxRate = 0 + -- -- local T0 @@ -106,10 +115,11 @@ void read_idx(std::vector &dims, std::vector &data, const char *name) local ok, err = pcall(function() local device = argc > 1 and tonumber(argv[1]) or 0 + local Lib = GetLib() -- Select a device and display arrayfire info - af.af_set_device(device) - af.af_info() + Lib.setDevice(device) + Lib.info() func(argc, argv) @@ -133,34 +143,32 @@ void read_idx(std::vector &dims, std::vector &data, const char *name) -- Are the windows redirecting IO? -- - progress = function() ---[[ -bool progress(unsigned iter_curr, af::timer t, double time_total) -{ - static unsigned iter_prev = 0; - static double time_prev = 0; - static double max_rate = 0; + progress = function(iter_curr, t, time_total) + local Lib = GetLib() - af::sync(); - double time_curr = af::timer::stop(t); + Lib.sync() - if ((time_curr - time_prev) < 1) return true; + local time_curr = Lib.timer_stop(t) - double rate = (iter_curr - iter_prev) / (time_curr - time_prev); - printf(" iterations per second: %.0f (progress %.0f%%)\n", - rate, 100.0f * time_curr / time_total); + if time_curr - TimePrev < 1 then + return true + end - max_rate = std::max(max_rate, rate); + local rate = (iter_curr - IterPrev) / (time_curr - TimePrev) - iter_prev = iter_curr; - time_prev = time_curr; + Lib.printf(" iterations per second: %.0f (progress %.0f%%)", rate, 100.0 * time_curr / time_total) + MaxRate = max(MaxRate, rate) - if (time_curr < time_total) return true; + IterPrev = iter_curr + TimePrev = time_curr - printf(" ### %f iterations per second (max)\n", max_rate); - return false; -]] + if time_curr < time_total then + return true + end + + Lib.printf(" ### %f iterations per second (max)", MaxRate) + return false end, -- @@ -178,10 +186,53 @@ bool progress(unsigned iter_curr, af::timer t, double time_total) end, -- - timer_stop = function() - return clock() - T0 + timer_stop = function(since) + return clock() - (since or T0) + end, + + -- + wait_for_windows = function(how, w1, w2, w3) + -- + local function any_closed () + if w3 then + return w1:close() or w2:close() or w3:close() + elseif w2 then + return w1:close() or w2:close() + else + return w1:close() + end + end + + -- + if how == "until" then + return any_closed + else + return function() + return not any_closed() + end + end + end, + + -- + wait_for_windows_close = function(how, w1, w2, w3) + local done = GetLib().wait_for_windows(how, w1, w2, w3) + + return function() + if done() then + w1:destroy() -- TODO: Do this right... how? + + if w2 then + w2:destroy() + end + + if w3 then + w3:destroy() + end + end + end + -- end - -- ^^^ TODO: Better stuff? (resolution, instantiable) + -- ^^^ TODO: Better stuff? (resolution, instantiable?) } do into[k] = v end diff --git a/scripts/tests/pde/swe.lua b/scripts/tests/pde/swe.lua index 7644fa8..16e9a87 100644 --- a/scripts/tests/pde/swe.lua +++ b/scripts/tests/pde/swe.lua @@ -4,7 +4,7 @@ local floor = math.floor -- Modules -- local AF = require("lib.af_lib") -lib.main(function(argc, arv) +AF.main(function(argc, arv) local win local function normalize (a, max) local mx = max * 0.5 @@ -14,32 +14,32 @@ lib.main(function(argc, arv) local function swe (console) local time_total = 20 -- run for N seconds -- Grid length, number and spacing - local Lx, nx = 512, Lx + 1 - local Ly, ny = 512, Ly + 1 + local Lx, Ly = 512, 512 + local nx, ny = Lx + 1, Ly + 1 local dx = Lx / (nx - 1) local dy = Ly / (ny - 1) local ZERO = AF.constant(0, nx, ny) local um, vm = ZERO:copy(), ZERO:copy() local io, jo, k = floor(Lx / 5.0), floor(Ly / 5.0), 20 - -- local x = tile(moddims(AF.seq(nx),nx,1), 1,ny) - -- local y = tile(moddims(AF.seq(ny),1,ny), nx,1) + local x = AF.tile(AF.moddims(AF.array(AF.seq(nx)),nx,1), 1,ny) + local y = AF.tile(AF.moddims(AF.array(AF.seq(ny)),1,ny), nx,1) -- Initial condition local etam = 0.01 * AF.exp((-((x - io) * (x - io) + (y - jo) * (y - jo))) / (k * k)) local m_eta = AF.max("f32", etam) local eta = etam:copy() local dt = 0.5 -- conv kernels - local h_diff_kernel[] = {9.81 * (dt / dx), 0, -9.81 * (dt / dx)} - local h_lap_kernel[] = {0, 1, 0, 1, -4, 1, 0, 1, 0} + local h_diff_kernel = {9.81 * (dt / dx), 0, -9.81 * (dt / dx)} + local h_lap_kernel = {0, 1, 0, 1, -4, 1, 0, 1, 0} local h_diff_kernel_arr = AF.array(3, h_diff_kernel) local h_lap_kernel_arr = AF.array(3, 3, h_lap_kernel) if not console then win = AF.Window(512, 512,"Shallow Water Equations") win:setColorMap("AF_COLORMAP_MOOD") end --- timer t = timer::start(); + local t = AF.timer_start() local iter = 0 - AF.EnvLoopWhile_Args(function(env) + AF.EnvLoopWhile_Mode(function(env) -- compute local up = um + AF.convolve(eta, h_diff_kernel_arr) local vp = um + AF.convolve(eta, h_diff_kernel_arr:T()) @@ -55,7 +55,7 @@ lib.main(function(argc, arv) end iter = iter + 1 end, function() - -- return AF.progress(iter, t, time_total) + return AF.progress(iter, t, time_total) end, "normal_gc") -- evict old states every now and then end diff --git a/src/Lua/arrayfire/methods/MoveReorder.cpp b/src/Lua/arrayfire/methods/MoveReorder.cpp index 51aaa0a..d1c2d73 100644 --- a/src/Lua/arrayfire/methods/MoveReorder.cpp +++ b/src/Lua/arrayfire/methods/MoveReorder.cpp @@ -51,7 +51,7 @@ static const struct luaL_Reg move_reorder_methods[] = { af_array * arr_ud = NewArray(L);// arr, ndims, dims, arr_ud - af_err err = af_moddims(arr_ud, GetArray(L, 2), dims.GetNDims(), dims.GetDims()); + af_err err = af_moddims(arr_ud, GetArray(L, 1), dims.GetNDims(), dims.GetDims()); return PushErr(L, err); // arr, ndims, dims, err, arr_ud }