Skip to content

Commit

Permalink
ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Jun 11, 2024
1 parent dbe39a5 commit a624bb6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
8 changes: 5 additions & 3 deletions einx/backend/_tinygrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def scalar_to_tensor(x):
)
else:
return x

def elementwise(func, convert_all_to_tensor=False):
@einx.trace
@functools.wraps(func)
Expand All @@ -31,6 +31,7 @@ def outer(*args):
args = [a for a in args]
args[0] = scalar_to_tensor(args[0])
return op.elementwise(func)(*args)

return outer

def reduce(func):
Expand All @@ -53,6 +54,7 @@ def reduce(tensor, axis=None, **kwargs):
if "keepdims" in kwargs:
kwargs["keepdim"] = kwargs.pop("keepdims")
return tracer.apply(func, args=[tensor], kwargs=kwargs, output=tracer.Tensor(shape))

return reduce

def to_dtype(x):
Expand Down Expand Up @@ -114,7 +116,7 @@ def einsum(backend, equation, *tensors):
scalars = scalars[1:]
for scalar in scalars:
x = backend.multiply(x, scalar)

return x

@staticmethod
Expand Down Expand Up @@ -196,7 +198,7 @@ def subtract_at(tensor, coordinates, updates):
@staticmethod
@einx.trace
def stop_gradient(tensor):
return tensor # TODO: set requires_grad to False?
return tensor # TODO: set requires_grad to False?

@staticmethod
@einx.trace
Expand Down
5 changes: 4 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,13 +255,16 @@ def wrap(op):

if importlib.util.find_spec("tinygrad"):
import os

os.environ["PYTHON"] = "1"
from tinygrad import Tensor

backend = einx.backend.tinygrad.create()

test = types.SimpleNamespace(
full=lambda shape, value=0.0, dtype="float32": Tensor.full(shape, value, dtype=backend.to_dtype(dtype)),
full=lambda shape, value=0.0, dtype="float32": Tensor.full(
shape, value, dtype=backend.to_dtype(dtype)
),
to_tensor=Tensor,
to_numpy=lambda x: x.numpy(),
)
Expand Down

0 comments on commit a624bb6

Please sign in to comment.