Skip to content

Commit

Permalink
Rest of array methods
Browse files Browse the repository at this point in the history
Added 5.3 operators
  • Loading branch information
ggcrunchy committed Dec 20, 2015
1 parent 355f8c7 commit bdaf733
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 16 deletions.
21 changes: 14 additions & 7 deletions scripts/lib/impl/operators.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
local af = require("arrayfire")

-- Forward declarations --
local CallWrap
local GetLib
local TwoArrays

Expand Down Expand Up @@ -35,6 +36,7 @@ end
function M.Add (array_module, meta)
-- Import these here since the array module is not yet registered.
GetLib = array_module.GetLib
CallWrap = array_module.CallWrap
TwoArrays = array_module.TwoArrays

--
Expand All @@ -49,30 +51,35 @@ function M.Add (array_module, meta)
--
for k, v in pairs{
__add = Binary("add"),
__band = Binary("bitand"),
__bnot = function(a)
return CallWrap("af_not", a:get())
end,
__bor = Binary("bitor"),
__bxor = Binary("bitxor"),
__call = function(a, ...)
-- operator()... ugh (proxy types, __index and __newindex shenanigans)
end,
__div = Binary("div"),
__eq = Binary("eq", true),
__lt = Binary("lt", true),
__le = Binary("le", true),
__mod = Binary("mod"),
__mod = Binary("rem"),
__mul = Binary("mul"),
--[[
__newindex = function(arr, k, v)
__newindex = function(a, k, v)
-- TODO: disable for non-proxies?

if k == "_" then
-- lvalue assign of v
end
end
]]
end,
__pow = Binary("pow"),
__shl = Binary("bitshiftl"),
__shr = Binary("bitshiftr"),
__sub = Binary("sub"),
__unm = function(a)
return 0 - a
end,
-- TODO: 5.3 supports bitwise ops...
end
} do
meta[k] = v
end
Expand Down
78 changes: 69 additions & 9 deletions scripts/lib/methods/methods.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,26 @@ function M.Add (array_module, meta)
local CallWrap = array_module.CallWrap
local GetLib = array_module.GetLib

--
local function Wrap (name)
name = "af_" .. name

return function(arr)
return CallWrap(name, arr:get())
end
end

-- --
local SizeOf = {}

for prefix, size in ("f32 f64 s32 u32 s64 u64 u8 b8 c32 c64 s16 u16"):gmatch "(%a)(%d+)" do
local k = af[prefix .. size]

if k then -- account for earlier versions
SizeOf[k] = tonumber(size) / (prefix == "c" and 4 or 8) -- 8 bits to a byte; double complex types
end
end

--
for k, v in pairs{
--
Expand All @@ -25,10 +45,16 @@ function M.Add (array_module, meta)
end,

--
copy = function(arr)
return CallWrap("af_copy_array", arr:get())
bytes = function(arr)
local ha = arr:get()
local n, dtype = Call("af_get_elements", ha), Call("af_get_type", ha)

return n * (SizeOf[dtype] or 4)
end,

--
copy = Wrap("copy_array"),

--
dims = function(arr, i)
if i then
Expand All @@ -39,18 +65,50 @@ function M.Add (array_module, meta)
end,

--
elements = function(arr)
return Call("af_get_elements", arr:get())
end,
elements = Wrap("get_elements"),

--
eval = function(arr)
Call("af_eval", arr:get())
end,
eval = Wrap("eval"),

--
get = array_module.GetHandle,

--
isbool = Wrap("is_bool"),

--
iscolumn = Wrap("is_column"),

--
iscomplex = Wrap("is_complex"),

--
isdouble = Wrap("is_double"),

--
isempty = Wrap("is_empty"),

--
isfloating = Wrap("is_floating"),

--
isinteger = Wrap("is_integer"),

--
isrealfloating = Wrap("is_real_floating"),

--
isrow = Wrap("is_row"),

--
isscalar = Wrap("is_scalar"),

--
issingle = Wrap("is_single"),

--
isvector = Wrap("is_vector"),

--
H = function(arr)
return GetLib().transpose(arr, true)
Expand All @@ -67,7 +125,9 @@ function M.Add (array_module, meta)
--
T = function(arr)
return GetLib().transpose(arr)
end
end,

type = Wrap("get_type")
} do
meta[k] = v
end
Expand Down

0 comments on commit bdaf733

Please sign in to comment.