diff --git a/scripts/lib/funcs/vector.lua b/scripts/lib/funcs/vector.lua index 27dff6d..95d5c36 100644 --- a/scripts/lib/funcs/vector.lua +++ b/scripts/lib/funcs/vector.lua @@ -1,6 +1,7 @@ --- Vector functions. -- Standard library imports -- +local assert = assert local min = math.min local select = select local type = type @@ -20,6 +21,7 @@ local ToType = array.ToType local M = {} -- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/reduce.cpp +-- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/where.cpp -- local function Bool (value) @@ -172,7 +174,14 @@ function M.Add (into) end, -- - sum = ReduceNaN("sum") + sum = ReduceNaN("sum"), + + -- + where = function(in_arr) + assert(not GetLib().gforGet(), "WHERE can not be used inside GFOR") -- TODO: AF_ERR_RUNTIME); + + return CallWrap("af_where", in_arr:get()) + end } do into[k] = v end diff --git a/scripts/lib/impl/array.lua b/scripts/lib/impl/array.lua index 5b7faca..9cd3bb6 100644 --- a/scripts/lib/impl/array.lua +++ b/scripts/lib/impl/array.lua @@ -139,17 +139,10 @@ end --- DOCME -- @tparam LuaArray arr --- @bool remove -- @treturn ?|af_array|nil X -function M.GetHandle (arr, remove) +function M.GetHandle (arr) -- TODO: If proxy, add reference? - local ha = arr.m_handle - - if remove then - arr.m_handle = nil - end - - return ha + return arr.m_handle end -- -- @@ -188,6 +181,8 @@ function M.IsConstant (item) return not not Constants[item] -- metatable redundant; coerce to false if missing end +-- TODO: IsProxy(), MakeProxy()... + --- DOCME -- @tparam LuaArray arr -- @tparam ?|af_array|nil handle @@ -237,9 +232,8 @@ function M.ToType (ret_type, real, imag) if rtype == "c32" or rtype == "c64" then return { real = real, imag = imag } else - return real + return real -- TODO: Improve this! end - -- TODO: Improve these a bit end --- DOCME @@ -279,7 +273,7 @@ end function M.WrapArray (arr) local wrapped = setmetatable({ m_handle = arr }, ArrayMethodsAndMetatable) - _AddToCurrentEnvironment_(wrapped) + _AddToCurrentEnvironment_("array", wrapped) return wrapped end @@ -299,8 +293,8 @@ end for _, v in ipairs{ "lib.impl.ephemeral", "lib.impl.operators", - "lib.impl.seq", - "lib.impl.index", -- depends on seq + "lib.impl.seq", -- depends on ephemeral + "lib.impl.index", -- depends on ephemeral, seq "lib.methods.methods" } do require(v).Add(M, ArrayMethodsAndMetatable) @@ -309,6 +303,17 @@ end ArrayMethodsAndMetatable.__index = ArrayMethodsAndMetatable ArrayMethodsAndMetatable.__metatable = MetaValue +-- Register array environment type. +M.RegisterEnvironmentCleanup("array", function(arr) + local ha = arr:get() + + arr.m_handle = nil -- set() can error out + + return not ha or af.af_release_array(ha) == SUCCESS + -- TODO: pooling? +end, "Errors releasing %i arrays") +-- TODO: Register "array_proxy"? + -- By default, check valid names. M.CheckNames(true) diff --git a/scripts/lib/impl/ephemeral.lua b/scripts/lib/impl/ephemeral.lua index d1a758c..31d8b79 100644 --- a/scripts/lib/impl/ephemeral.lua +++ b/scripts/lib/impl/ephemeral.lua @@ -2,6 +2,7 @@ -- Standard library imports -- local assert = assert +local concat = table.concat local collectgarbage = collectgarbage local error = error local pairs = pairs @@ -9,12 +10,6 @@ local pcall = pcall local rawequal = rawequal local remove = table.remove --- Modules -- -local af = require("arrayfire") - --- Forward declarations -- -local IsArray - -- Cookies -- local _command = {} @@ -30,62 +25,88 @@ local Stack = {} -- -- local Top = 0 +-- +local function Remove (lists, elem) + for elem_type, list in pairs(lists) do + if list[elem] then + list[elem] = nil + + return elem_type + end + end +end + +-- -- +local Types = {} + -- local function NewEnv () - local id, list, mode, step = ID, {} + local id, lists, mode, step = ID, {} ID = ID + 1 + for elem_type in pairs(Types) do + lists[elem_type] = {} + end + return function(a, b, c) if rawequal(a, _command) then -- a: _command, b: what, c: arg if b == "set_mode" then mode = c elseif b == "get_id" then return id - elseif b == "get_list" then - return list + elseif b == "get_lists" then + return lists elseif b == "set_step" then step = c end elseif a == "get_step" then -- a: "get_step" return step - elseif IsArray(a) then -- a: array? + else -- a: element? local env = Stack[Top] assert(env and env(_command, "get_id") == id, "Environment not active") -- is self? - local lower_env = (mode == "parent" or mode == "parent_gc") and Stack[Top - 1] --- TODO: pingpong, pingpong_gc - if lower_env then - lower_env(_command, "get_list")[a] = true - end + local elem_type = Remove(lists, a) - list[a] = nil + if elem_type then + local lower_env = (mode == "parent" or mode == "parent_gc") and Stack[Top - 1] +-- TODO: pingpong, pingpong_gc + if lower_env then + lower_env(_command, "get_lists")[elem_type][a] = true + end - return a + return a + end end end end -- -local function Purge (list) - local nerrs = 0 +local function Purge (lists) + local errs - for arr in pairs(list) do - local ha = arr:get(true) + for elem_type, type_info in pairs(Types) do + local elem_list, cleanup, nerrs = lists[elem_type], type_info.cleanup, 0 - if ha then - local err = af.af_release_array(ha) - - if err ~= af.AF_SUCCESS then + -- + for elem in pairs(elem_list) do + if not cleanup(elem) then nerrs = nerrs + 1 end + + elem_list[elem] = nil end - list[arr] = nil + -- + if nerrs > 0 then + errs = errs or {} + + errs[#errs + 1] = type_info.message:format(nerrs) + end end - return nerrs + return errs and concat(errs, "\n") end -- -- @@ -97,7 +118,7 @@ local function GetResults (env, ok, a, ...) env(_command, "set_mode", nil) - local nerrs = Purge(env(_command, "get_list")) + local errs = Purge(env(_command, "get_lists")) -- Pingpong or normal? (How to end?) Cache[#Cache + 1] = env Top, Stack[Top] = Top - 1 @@ -106,25 +127,22 @@ local function GetResults (env, ok, a, ...) collectgarbage() end -- TODO: pingpong_gc - if ok and nerrs == 0 then + if ok and not errs then return a, ... else -- Clean up if pingpong - error(not ok and a or ("Errors releasing %i arrays"):format(nerrs)) + error(not ok and a or errs) end end -- function M.Add (array_module) - -- Import these here since the array module is not yet registered. - IsArray = array_module.IsArray - -- - function array_module.AddToCurrentEnvironment (arr) + function array_module.AddToCurrentEnvironment (elem_type, arr) local env = Top > 0 and Stack[Top] if env then - env(_command, "get_list")[arr] = true + env(_command, "get_lists")[elem_type][arr] = true end end -- AddOneEnv @@ -154,6 +172,13 @@ function M.Add (array_module) return GetResults(env, pcall(func, env, ...)) end + + -- + function array_module.RegisterEnvironmentCleanup (elem_type, cleanup, message) + assert(Top == 0 and #Cache == 0, "Attempt to register new environment type after launch") + + Types[elem_type] = { cleanup = cleanup, message = message } + end end -- Export the module. diff --git a/scripts/lib/impl/index.lua b/scripts/lib/impl/index.lua index b021ab1..b94fabc 100644 --- a/scripts/lib/impl/index.lua +++ b/scripts/lib/impl/index.lua @@ -8,12 +8,211 @@ local af = require("arrayfire") -- Exports -- local M = {} +-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/index.cpp + -- function M.Add (array_module) -- Import these here since the array module is not yet registered. +--[[ +/// +/// \brief Wrapper for af_index. +/// +/// This class is a wrapper for the af_index struct in the C interface. It +/// allows implicit type conversion from valid indexing types like int, +/// \ref af::seq, \ref af_seq, and \ref af::array. +/// +/// \note This is a helper class and does not necessarily need to be created +/// explicitly. It is used in the operator() overloads to simplify the API. +/// +class AFAPI index { + + af_index_t impl; + public: + /// + /// \brief Default constructor. Equivalent to \ref af::span + /// + index(); + ~index(); + + /// + /// \brief Implicit int converter + /// + /// Indexes the af::array at index \p idx + /// + /// \param[in] idx is the id of the index + /// + /// \sa indexing + /// + index(const int idx); + + /// + /// \brief Implicit seq converter + /// + /// Indexes the af::array using an \ref af::seq object + /// + /// \param[in] s0 is the set of indices to parse + /// + /// \sa indexing + /// + index(const af::seq& s0); + + /// + /// \brief Implicit seq converter + /// + /// Indexes the af::array using an \ref af_seq object + /// + /// \param[in] s0 is the set of indices to parse + /// + /// \sa indexing + /// + index(const af_seq& s0); + + /// + /// \brief Implicit int converter + /// + /// Indexes the af::array using an \ref af::array object + /// + /// \param[in] idx0 is the set of indices to parse + /// + /// \sa indexing + /// + index(const af::array& idx0); + +#if AF_API_VERSION >= 31 + /// + /// \brief Copy constructor + /// + /// \param[in] idx0 is index to copy. + /// + /// \sa indexing + /// + index(const index& idx0); +#endif + + /// + /// \brief Returns true if the \ref af::index represents a af::span object + /// + /// \returns true if the af::index is an af::span + /// + bool isspan() const; + + /// + /// \brief Gets the underlying af_index_t object + /// + /// \returns the af_index_t represented by this object + /// + const af_index_t& get() const; +]] + +--[[ +array lookup(const array &in, const array &idx, const int dim) +{ + af_array out = 0; + AF_THROW(af_lookup(&out, in.get(), idx.get(), getFNSD(dim, in.dims()))); + return array(out); +} + +void copy(array &dst, const array &src, + const index &idx0, + const index &idx1, + const index &idx2, + const index &idx3) +{ + unsigned nd = dst.numdims(); + + af_index_t indices[] = {idx0.get(), + idx1.get(), + idx2.get(), + idx3.get()}; + + af_array lhs = dst.get(); + const af_array rhs = src.get(); + AF_THROW(af_assign_gen(&lhs, lhs, nd, indices, rhs)); +} + +index::index() { + impl.idx.seq = af_span; + impl.isSeq = true; + impl.isBatch = false; +} +index::index(const int idx) { + impl.idx.seq = af_make_seq(idx, idx, 1); + impl.isSeq = true; + impl.isBatch = false; +} + +index::index(const af::seq& s0) { + impl.idx.seq = s0.s; + impl.isSeq = true; + impl.isBatch = s0.m_gfor; +} + +index::index(const af_seq& s0) { + impl.idx.seq = s0; + impl.isSeq = true; + impl.isBatch = false; +} + +index::index(const af::array& idx0) { + array idx = idx0.isbool() ? where(idx0) : idx0; + af_array arr = 0; + AF_THROW(af_retain_array(&arr, idx.get())); + impl.idx.arr = arr; + + impl.isSeq = false; + impl.isBatch = false; +} + +index::index(const af::index& idx0) { + *this = idx0; +} + +index::~index() { + if (!impl.isSeq) + af_release_array(impl.idx.arr); +} + +index & index::operator=(const index& idx0) { + impl = idx0.get(); + if(impl.isSeq == false){ + // increment reference count to avoid double free + // when/if idx0 is destroyed + AF_THROW(af_retain_array(&impl.idx.arr, impl.idx.arr)); + } + return *this; +} + +#if __cplusplus > 199711L +index::index(index &&idx0) { + impl = idx0.impl; +} + +index& index::operator=(index &&idx0) { + impl = idx0.impl; + return *this; +} +#endif + + +static bool operator==(const af_seq& lhs, const af_seq& rhs) { + return lhs.begin == rhs.begin && lhs.end == rhs.end && lhs.step == rhs.step; +} + +bool index::isspan() const +{ + return impl.isSeq == true && impl.idx.seq == af_span; +} + +const af_index_t& index::get() const +{ + return impl; +} +]] end +-- TODO: Add "index" environment type? + -- Export the module. return M \ No newline at end of file diff --git a/scripts/lib/impl/seq.lua b/scripts/lib/impl/seq.lua index baff2d7..8829fca 100644 --- a/scripts/lib/impl/seq.lua +++ b/scripts/lib/impl/seq.lua @@ -164,5 +164,7 @@ end SeqMT.__index = SeqMT SeqMT.__metatable = MetaValue +-- TODO: Add "seq" environment type? + -- Export the module. return M \ No newline at end of file