Skip to content

Commit

Permalink
Provisional gfor() iterator, calling gforGet() in binary operators an…
Browse files Browse the repository at this point in the history
…d math ops

Use in rainfall sample (but need array proxies to finish)

Stub for index objects

Sprinkling of notes about how to add array proxies, tentative idea being to integrate them as arrays with different GetHandle() and SetHandle() behavior, and various operations disabled
  • Loading branch information
ggcrunchy committed Dec 18, 2015
1 parent 8f2979e commit 946b285
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 11 deletions.
1 change: 1 addition & 0 deletions scripts/lib/af_lib.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ for _, v in ipairs{
"funcs.statistics",
"funcs.util",
"funcs.vector",
"funcs.gfor",
"graphics.window",
"methods.constructors",
"methods.device",
Expand Down
68 changes: 68 additions & 0 deletions scripts/lib/funcs/gfor.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
--- gfor() mechanism.

-- Standard library imports --
local assert = assert

-- Modules --
local array = require("lib.impl.array")

-- Imports --
local GetLib = array.GetLib

-- Exports --
local M = {}

--- See also: https://github.com/arrayfire/arrayfire/blob/devel/include/af/gfor.h
-- https://github.com/arrayfire/arrayfire/blob/devel/src/api/cpp/gfor.cpp

-- --
local Status = false

--
local function AuxGfor (seq)
Status = not Status

if Status then
return seq
end
end

--
function M.Add (into)
for k, v in pairs{
--
batchFunc = function(lhs, rhs, func)
assert(not Status, "batchFunc can not be used inside GFOR") -- TODO: AF_ERR_ARG

Status = true

local res = func(lhs, rhs)

Status = false

return res
end,

--
gfor = function(...)
local lib = GetLib()

return AuxGfor, lib.seq(lib.seq(...), true)
end,

--
gforGet = function()
return Status
end,

--
gforSet = function(val)
Status = not not val
end
} do
into[k] = v
end
end

-- Export the module.
return M
3 changes: 2 additions & 1 deletion scripts/lib/funcs/mathematics.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ local array = require("lib.impl.array")

-- Imports --
local CallWrap = array.CallWrap
local GetLib = array.GetLib
local IsArray = array.IsArray
local TwoArrays = array.TwoArrays

Expand All @@ -17,7 +18,7 @@ local M = {}
--
local function Binary (name)
return function(a, b)
return TwoArrays(name, a, b--[[TODO: IsArray(a) and IsArray(b) and gfor_get]])
return TwoArrays(name, a, b, IsArray(a) and IsArray(b) and GetLib().gforGet())
end
end

Expand Down
3 changes: 3 additions & 0 deletions scripts/lib/impl/array.lua
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ end
-- @bool remove
-- @treturn ?|af_array|nil X
function M.GetHandle (arr, remove)
-- TODO: If proxy, add reference?
local ha = arr.m_handle

if remove then
Expand Down Expand Up @@ -191,6 +192,7 @@ end
-- @tparam LuaArray arr
-- @tparam ?|af_array|nil handle
function M.SetHandle (arr, handle)
-- TODO: disable for proxies
local cur = arr.m_handle

if cur ~= nil then
Expand Down Expand Up @@ -298,6 +300,7 @@ for _, v in ipairs{
"lib.impl.ephemeral",
"lib.impl.operators",
"lib.impl.seq",
"lib.impl.index", -- depends on seq
"lib.methods.methods"
} do
require(v).Add(M, ArrayMethodsAndMetatable)
Expand Down
19 changes: 19 additions & 0 deletions scripts/lib/impl/index.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
--- Core index module.

-- Modules --
local af = require("arrayfire")

-- Forward declarations --

-- Exports --
local M = {}

--
function M.Add (array_module)
-- Import these here since the array module is not yet registered.


end

-- Export the module.
return M
14 changes: 13 additions & 1 deletion scripts/lib/impl/operators.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
local af = require("arrayfire")

-- Forward declarations --
local GetLib
local TwoArrays

-- Exports --
Expand All @@ -17,9 +18,10 @@ local function Binary (name, cmp)
name = "af_" .. name

return function(a, b)
-- TODO: disable for proxies?
Result = nil

local arr = TwoArrays(name, a, b, true) -- TODO: gforGet()
local arr = TwoArrays(name, a, b, GetLib().gforGet())

if cmp then
Result = arr
Expand All @@ -32,6 +34,7 @@ end
--
function M.Add (array_module, meta)
-- Import these here since the array module is not yet registered.
GetLib = array_module.GetLib
TwoArrays = array_module.TwoArrays

--
Expand All @@ -55,6 +58,15 @@ function M.Add (array_module, meta)
__le = Binary("le", true),
__mod = Binary("mod"),
__mul = Binary("mul"),
--[[
__newindex = function(arr, k, v)
-- TODO: disable for non-proxies?
if k == "_" then
-- lvalue assign of v
end
end
]]
__pow = Binary("pow"),
__sub = Binary("sub"),
__unm = function(a)
Expand Down
16 changes: 7 additions & 9 deletions scripts/tests/getting_started/rainfall.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,20 @@ AF.main(function()
local site = AF.array(n,site_)
local measurement = AF.array(n,measurement_)
local rainfall = AF.constant(0, sites)
--[[
gfor (seq s, sites) {
rainfall(s) = sum(measurement * (site == s));
}
]]

for s in AF.gfor(sites) do
-- rainfall(s) = AF.sum(measurement * COMP(site == AF.array(s)))
end
print("total rainfall at each site:")
AF.print("rainfall", rainfall)
local is_between = AF["and"](Comp(WC(1) <= day), Comp(day <= WC(5))) -- days 1 and 5
local rain_between = AF.sum("f32", measurement * is_between)
AF.printf("rain between days: %g", rain_between)
AF.printf("number of days with rain: %g", AF.sum("f32", Comp(AF.diff1(day) > WC(0))) + 1)
local per_day = AF.constant(0, days)
--[[
gfor (seq d, days)
per_day(d) = sum(measurement * (day == d))
]]
for d in AF.gfor(days) do
-- per_day(d) = AF.sum(measurement * COMP(day == AF.array(d)))
end
print("total rainfall each day:")
AF.print("per_day", per_day)
AF.printf("number of days over five: %g", AF.sum("f32", Comp(per_day > WC(5))))
Expand Down

0 comments on commit 946b285

Please sign in to comment.