Skip to content

Commit

Permalink
Fix compatibility with dask's newest einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Nov 28, 2024
1 parent 8478fd9 commit 994643d
Showing 1 changed file with 44 additions and 1 deletion.
45 changes: 44 additions & 1 deletion einx/backend/_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,50 @@ def reshape(tensor, shape):

transpose = op.transpose(tda.transpose)
broadcast_to = op.broadcast_to(tda.broadcast_to)
einsum = op.einsum(tda.einsum)

@classmethod
@einx.trace
def einsum(backend, equation, *operands):
exprs = equation.split("->")
if len(exprs) != 2:
raise ValueError("Invalid einsum equation")
in_exprs = exprs[0].split(",")
out_expr = exprs[1]

# Remove scalars
scalars = []
for in_expr, operand in zip(in_exprs, operands):
if (len(in_expr) == 0) != (operand.shape == ()):
raise ValueError(
f"Tensor and einsum expression do not match: {in_expr} and {operand.shape}"
)
if operand.shape == ():
scalars.append(operand)
operands = [operand for operand in operands if operand.shape != ()]
in_exprs = [in_expr for in_expr in in_exprs if len(in_expr) > 0]
assert len(in_exprs) == len(operands)
equation = ",".join(in_exprs) + "->" + out_expr

# Call without scalars
if len(operands) == 1:
if in_exprs[0] != out_expr:
output = op.einsum(tda.einsum)(equation, *operands)
else:
output = operands[0]
elif len(operands) > 1:
output = op.einsum(tda.einsum)(equation, *operands)
else:
output = None

# Multiply scalars
if len(scalars) > 0:
if output is None:
output = backend.multiply(*scalars)
else:
output = backend.multiply(output, *scalars)

return output

arange = op.arange(tda.arange)

@staticmethod
Expand Down

0 comments on commit 994643d

Please sign in to comment.