From bdaf733e9d0d1082e3b10a9d7dabc47c0622be25 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Sat, 19 Dec 2015 21:04:19 -0600 Subject: [PATCH] Rest of array methods Added 5.3 operators --- scripts/lib/impl/operators.lua | 21 ++++++--- scripts/lib/methods/methods.lua | 78 +++++++++++++++++++++++++++++---- 2 files changed, 83 insertions(+), 16 deletions(-) diff --git a/scripts/lib/impl/operators.lua b/scripts/lib/impl/operators.lua index 580339a..a34708a 100644 --- a/scripts/lib/impl/operators.lua +++ b/scripts/lib/impl/operators.lua @@ -4,6 +4,7 @@ local af = require("arrayfire") -- Forward declarations -- +local CallWrap local GetLib local TwoArrays @@ -35,6 +36,7 @@ end function M.Add (array_module, meta) -- Import these here since the array module is not yet registered. GetLib = array_module.GetLib + CallWrap = array_module.CallWrap TwoArrays = array_module.TwoArrays -- @@ -49,6 +51,12 @@ function M.Add (array_module, meta) -- for k, v in pairs{ __add = Binary("add"), + __band = Binary("bitand"), + __bnot = function(a) + return CallWrap("af_not", a:get()) + end, + __bor = Binary("bitor"), + __bxor = Binary("bitxor"), __call = function(a, ...) -- operator()... ugh (proxy types, __index and __newindex shenanigans) end, @@ -56,23 +64,22 @@ function M.Add (array_module, meta) __eq = Binary("eq", true), __lt = Binary("lt", true), __le = Binary("le", true), - __mod = Binary("mod"), + __mod = Binary("rem"), __mul = Binary("mul"), ---[[ - __newindex = function(arr, k, v) + __newindex = function(a, k, v) -- TODO: disable for non-proxies? if k == "_" then -- lvalue assign of v end - end -]] + end, __pow = Binary("pow"), + __shl = Binary("bitshiftl"), + __shr = Binary("bitshiftr"), __sub = Binary("sub"), __unm = function(a) return 0 - a - end, - -- TODO: 5.3 supports bitwise ops... + end } do meta[k] = v end diff --git a/scripts/lib/methods/methods.lua b/scripts/lib/methods/methods.lua index eed755c..654662b 100644 --- a/scripts/lib/methods/methods.lua +++ b/scripts/lib/methods/methods.lua @@ -17,6 +17,26 @@ function M.Add (array_module, meta) local CallWrap = array_module.CallWrap local GetLib = array_module.GetLib + -- + local function Wrap (name) + name = "af_" .. name + + return function(arr) + return CallWrap(name, arr:get()) + end + end + + -- -- + local SizeOf = {} + + for prefix, size in ("f32 f64 s32 u32 s64 u64 u8 b8 c32 c64 s16 u16"):gmatch "(%a)(%d+)" do + local k = af[prefix .. size] + + if k then -- account for earlier versions + SizeOf[k] = tonumber(size) / (prefix == "c" and 4 or 8) -- 8 bits to a byte; double complex types + end + end + -- for k, v in pairs{ -- @@ -25,10 +45,16 @@ function M.Add (array_module, meta) end, -- - copy = function(arr) - return CallWrap("af_copy_array", arr:get()) + bytes = function(arr) + local ha = arr:get() + local n, dtype = Call("af_get_elements", ha), Call("af_get_type", ha) + + return n * (SizeOf[dtype] or 4) end, + -- + copy = Wrap("copy_array"), + -- dims = function(arr, i) if i then @@ -39,18 +65,50 @@ function M.Add (array_module, meta) end, -- - elements = function(arr) - return Call("af_get_elements", arr:get()) - end, + elements = Wrap("get_elements"), -- - eval = function(arr) - Call("af_eval", arr:get()) - end, + eval = Wrap("eval"), -- get = array_module.GetHandle, + -- + isbool = Wrap("is_bool"), + + -- + iscolumn = Wrap("is_column"), + + -- + iscomplex = Wrap("is_complex"), + + -- + isdouble = Wrap("is_double"), + + -- + isempty = Wrap("is_empty"), + + -- + isfloating = Wrap("is_floating"), + + -- + isinteger = Wrap("is_integer"), + + -- + isrealfloating = Wrap("is_real_floating"), + + -- + isrow = Wrap("is_row"), + + -- + isscalar = Wrap("is_scalar"), + + -- + issingle = Wrap("is_single"), + + -- + isvector = Wrap("is_vector"), + -- H = function(arr) return GetLib().transpose(arr, true) @@ -67,7 +125,9 @@ function M.Add (array_module, meta) -- T = function(arr) return GetLib().transpose(arr) - end + end, + + type = Wrap("get_type") } do meta[k] = v end