Skip to content

Commit

Permalink
Added where()
Browse files Browse the repository at this point in the history
Copied over C++ code for index... still puzzling out a good attack :P

Allow subtypes in ephemeral environments (currently just arrays, but meant to support proxies, sequences, and indices at least)
  • Loading branch information
ggcrunchy committed Dec 20, 2015
1 parent c2dbe91 commit 355f8c7
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 50 deletions.
11 changes: 10 additions & 1 deletion scripts/lib/funcs/vector.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
--- Vector functions.

-- Standard library imports --
local assert = assert
local min = math.min
local select = select
local type = type
Expand All @@ -20,6 +21,7 @@ local ToType = array.ToType
local M = {}

-- See also: https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/reduce.cpp
-- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/where.cpp

--
local function Bool (value)
Expand Down Expand Up @@ -172,7 +174,14 @@ function M.Add (into)
end,

--
sum = ReduceNaN("sum")
sum = ReduceNaN("sum"),

--
where = function(in_arr)
assert(not GetLib().gforGet(), "WHERE can not be used inside GFOR") -- TODO: AF_ERR_RUNTIME);

return CallWrap("af_where", in_arr:get())
end
} do
into[k] = v
end
Expand Down
33 changes: 19 additions & 14 deletions scripts/lib/impl/array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -139,17 +139,10 @@ end

--- DOCME
-- @tparam LuaArray arr
-- @bool remove
-- @treturn ?|af_array|nil X
function M.GetHandle (arr, remove)
function M.GetHandle (arr)
-- TODO: If proxy, add reference?
local ha = arr.m_handle

if remove then
arr.m_handle = nil
end

return ha
return arr.m_handle
end

-- --
Expand Down Expand Up @@ -188,6 +181,8 @@ function M.IsConstant (item)
return not not Constants[item] -- metatable redundant; coerce to false if missing
end

-- TODO: IsProxy(), MakeProxy()...

--- DOCME
-- @tparam LuaArray arr
-- @tparam ?|af_array|nil handle
Expand Down Expand Up @@ -237,9 +232,8 @@ function M.ToType (ret_type, real, imag)
if rtype == "c32" or rtype == "c64" then
return { real = real, imag = imag }
else
return real
return real -- TODO: Improve this!
end
-- TODO: Improve these a bit
end

--- DOCME
Expand Down Expand Up @@ -279,7 +273,7 @@ end
function M.WrapArray (arr)
local wrapped = setmetatable({ m_handle = arr }, ArrayMethodsAndMetatable)

_AddToCurrentEnvironment_(wrapped)
_AddToCurrentEnvironment_("array", wrapped)

return wrapped
end
Expand All @@ -299,8 +293,8 @@ end
for _, v in ipairs{
"lib.impl.ephemeral",
"lib.impl.operators",
"lib.impl.seq",
"lib.impl.index", -- depends on seq
"lib.impl.seq", -- depends on ephemeral
"lib.impl.index", -- depends on ephemeral, seq
"lib.methods.methods"
} do
require(v).Add(M, ArrayMethodsAndMetatable)
Expand All @@ -309,6 +303,17 @@ end
ArrayMethodsAndMetatable.__index = ArrayMethodsAndMetatable
ArrayMethodsAndMetatable.__metatable = MetaValue

-- Register array environment type.
M.RegisterEnvironmentCleanup("array", function(arr)
local ha = arr:get()

arr.m_handle = nil -- set() can error out

return not ha or af.af_release_array(ha) == SUCCESS
-- TODO: pooling?
end, "Errors releasing %i arrays")
-- TODO: Register "array_proxy"?

-- By default, check valid names.
M.CheckNames(true)

Expand Down
95 changes: 60 additions & 35 deletions scripts/lib/impl/ephemeral.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@

-- Standard library imports --
local assert = assert
local concat = table.concat
local collectgarbage = collectgarbage
local error = error
local pairs = pairs
local pcall = pcall
local rawequal = rawequal
local remove = table.remove

-- Modules --
local af = require("arrayfire")

-- Forward declarations --
local IsArray

-- Cookies --
local _command = {}

Expand All @@ -30,62 +25,88 @@ local Stack = {}
-- --
local Top = 0

--
local function Remove (lists, elem)
for elem_type, list in pairs(lists) do
if list[elem] then
list[elem] = nil

return elem_type
end
end
end

-- --
local Types = {}

--
local function NewEnv ()
local id, list, mode, step = ID, {}
local id, lists, mode, step = ID, {}

ID = ID + 1

for elem_type in pairs(Types) do
lists[elem_type] = {}
end

return function(a, b, c)
if rawequal(a, _command) then -- a: _command, b: what, c: arg
if b == "set_mode" then
mode = c
elseif b == "get_id" then
return id
elseif b == "get_list" then
return list
elseif b == "get_lists" then
return lists
elseif b == "set_step" then
step = c
end
elseif a == "get_step" then -- a: "get_step"
return step
elseif IsArray(a) then -- a: array?
else -- a: element?
local env = Stack[Top]

assert(env and env(_command, "get_id") == id, "Environment not active") -- is self?

local lower_env = (mode == "parent" or mode == "parent_gc") and Stack[Top - 1]
-- TODO: pingpong, pingpong_gc
if lower_env then
lower_env(_command, "get_list")[a] = true
end
local elem_type = Remove(lists, a)

list[a] = nil
if elem_type then
local lower_env = (mode == "parent" or mode == "parent_gc") and Stack[Top - 1]
-- TODO: pingpong, pingpong_gc
if lower_env then
lower_env(_command, "get_lists")[elem_type][a] = true
end

return a
return a
end
end
end
end

--
local function Purge (list)
local nerrs = 0
local function Purge (lists)
local errs

for arr in pairs(list) do
local ha = arr:get(true)
for elem_type, type_info in pairs(Types) do
local elem_list, cleanup, nerrs = lists[elem_type], type_info.cleanup, 0

if ha then
local err = af.af_release_array(ha)

if err ~= af.AF_SUCCESS then
--
for elem in pairs(elem_list) do
if not cleanup(elem) then
nerrs = nerrs + 1
end

elem_list[elem] = nil
end

list[arr] = nil
--
if nerrs > 0 then
errs = errs or {}

errs[#errs + 1] = type_info.message:format(nerrs)
end
end

return nerrs
return errs and concat(errs, "\n")
end

-- --
Expand All @@ -97,7 +118,7 @@ local function GetResults (env, ok, a, ...)

env(_command, "set_mode", nil)

local nerrs = Purge(env(_command, "get_list"))
local errs = Purge(env(_command, "get_lists"))
-- Pingpong or normal? (How to end?)
Cache[#Cache + 1] = env
Top, Stack[Top] = Top - 1
Expand All @@ -106,25 +127,22 @@ local function GetResults (env, ok, a, ...)
collectgarbage()
end
-- TODO: pingpong_gc
if ok and nerrs == 0 then
if ok and not errs then
return a, ...
else
-- Clean up if pingpong
error(not ok and a or ("Errors releasing %i arrays"):format(nerrs))
error(not ok and a or errs)
end
end

--
function M.Add (array_module)
-- Import these here since the array module is not yet registered.
IsArray = array_module.IsArray

--
function array_module.AddToCurrentEnvironment (arr)
function array_module.AddToCurrentEnvironment (elem_type, arr)
local env = Top > 0 and Stack[Top]

if env then
env(_command, "get_list")[arr] = true
env(_command, "get_lists")[elem_type][arr] = true
end
end
-- AddOneEnv
Expand Down Expand Up @@ -154,6 +172,13 @@ function M.Add (array_module)

return GetResults(env, pcall(func, env, ...))
end

--
function array_module.RegisterEnvironmentCleanup (elem_type, cleanup, message)
assert(Top == 0 and #Cache == 0, "Attempt to register new environment type after launch")

Types[elem_type] = { cleanup = cleanup, message = message }
end
end

-- Export the module.
Expand Down
Loading

0 comments on commit 355f8c7

Please sign in to comment.