Skip to content

Commit

Permalink
Merge pull request #86 from mdhaber/matmul_speedup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucascolley authored Jan 28, 2025
2 parents 4711da6 + ce3f7f8 commit d488066
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions marray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,16 +358,15 @@ def get_linalg_fun(name):
def linalg_fun(x1, x2, /, **kwargs):
x1 = asarray(x1)
x2 = asarray(x2)
data1 = xp.asarray(x1.data, copy=True)
data2 = xp.asarray(x2.data, copy=True)
data1[x1.mask] = 0
data2[x2.mask] = 0
zero = xp.asarray(0)
data1 = xp.where(x1.mask, xp.asarray(0, dtype=x1.dtype), x1.data)
data2 = xp.where(x2.mask, xp.asarray(0, dtype=x2.dtype), x2.data)
fun = getattr(xp, name)
data = fun(data1, data2, **kwargs)
# Strict array can't do arithmetic with booleans
# mask = ~fun(~x1.mask, ~x2.mask)
mask = fun(xp.astype(~x1.mask, xp.uint64),
xp.astype(~x2.mask, xp.uint64), **kwargs)
mask = fun(xp.astype(~x1.mask, xp.float32),
xp.astype(~x2.mask, xp.float32), **kwargs)
mask = ~xp.astype(mask, xp.bool)
return MArray(data, mask)
return linalg_fun
Expand Down

0 comments on commit d488066

Please sign in to comment.