Add pre-commit
fferflo committed Jun 27, 2024
1 parent 0243839 commit 07aca8c
Showing 28 changed files with 128 additions and 94 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.4.4
+    hooks:
+      - id: ruff
+        args: [--fix]
+      - id: ruff-format
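
These two hooks run ruff's linter (with autofixes) and its formatter before each commit; every other file changed in this commit is the output of those autofixes. A hedged sketch of exercising the hooks by hand, in Python for consistency with the rest of the repo (assumes the pre-commit package is installed; `pre-commit install` and `pre-commit run --all-files` are the standard CLI entry points):

import subprocess

# Register the git hook once per clone.
subprocess.run(["pre-commit", "install"], check=True)
# Run all hooks against every file, as this commit effectively did.
subprocess.run(["pre-commit", "run", "--all-files"], check=True)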
2 changes: 1 addition & 1 deletion einx/backend/__init__.py
@@ -1,4 +1,4 @@
-from .register import register_for_module, register, get, backends, numpy
+from .register import register_for_module, register, get, backends
 from .base import Backend, get_default
 
 from . import _numpy as numpy
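
The dropped `numpy` here fixes a dead binding: the name was imported from `.register` and then immediately shadowed by `from . import _numpy as numpy`, so the first import could never be observed. A self-contained illustration of the same pattern on stdlib modules (attributing this to ruff's F811 redefinition rule is my reading; the diff itself doesn't name it):

# The first `dumps` binding is dead: the second import shadows it before
# any use, which linters report as a redefinition of an unused name.
from json import dumps    # dead binding
from pickle import dumps  # the binding callers actually get

print(type(dumps([1, 2])))  # <class 'bytes'>: pickle.dumps, not json.dumps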
5 changes: 3 additions & 2 deletions einx/backend/_dask.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial


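The same two-part rewrite repeats in every backend module below (_jax, _mlx, _numpy, _tensorflow, _tinygrad, _torch): the star import is narrowed to the names the module actually uses, and `import einx, types` is split to one module per statement. A stdlib sketch of both fixes (the rule codes F403 and E401 are my attribution):

# 1) Explicit imports make each dependency visible and checkable,
#    where a star import hides it:
from math import pi, sqrt  # was: from math import *

# 2) One module per import statement:
import os  # was: import os, sys
import sys

print(sqrt(pi), os.sep, sys.platform)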
5 changes: 3 additions & 2 deletions einx/backend/_jax.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial


5 changes: 3 additions & 2 deletions einx/backend/_mlx.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial


5 changes: 3 additions & 2 deletions einx/backend/_numpy.py
@@ -1,8 +1,9 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary, vmap_forloop
 import einx.tracer as tracer
 from einx.tracer.tensor import op
 import numpy as np
-import einx, types
+import einx
+import types
 from functools import partial


5 changes: 3 additions & 2 deletions einx/backend/_tensorflow.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial


9 changes: 5 additions & 4 deletions einx/backend/_tinygrad.py
@@ -1,7 +1,8 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 import functools

@@ -28,7 +29,7 @@ def outer(*args):
     if convert_all_to_tensor:
         args = [scalar_to_tensor(a) for a in args]
     else:
-        args = [a for a in args]
+        args = list(args)
         args[0] = scalar_to_tensor(args[0])
     return op.elementwise(func)(*args)

@@ -95,7 +96,7 @@ def einsum(backend, equation, *tensors):
     if len(inputs) != len(tensors):
         raise ValueError("Invalid equation")
     inputs = [x.strip().replace(" ", "") for x in inputs]
-    tensors = [t for t in tensors]
+    tensors = list(tensors)
 
     scalars = []
     for i in list(range(len(inputs)))[::-1]:
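Both tinygrad hunks replace an identity comprehension with `list()`: `[a for a in args]` builds exactly the shallow copy that `list(args)` states directly (ruff's C416 rule, as I read it). A minimal check:

args = (1, 2, 3)
copy_a = [a for a in args]  # before: identity comprehension
copy_b = list(args)         # after: same shallow copy, stated directly
assert copy_a == copy_b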
10 changes: 5 additions & 5 deletions einx/backend/_torch.py
@@ -1,8 +1,10 @@
-from .base import *
+from .base import Backend, associative_binary_to_nary, ErrorBackend
 import einx.tracer as tracer
 from einx.tracer.tensor import op
-import einx, types
+import einx
+import types
 from functools import partial
 import functools
 
+
 def create():
@@ -47,8 +49,6 @@ def wrapper(*args, **kwargs):

         return wrapper
 
-    MARKER_DECORATED_CONSTRUCT_GRAPH = "__einx_decorated_construct_graph"
-
     ttorch = tracer.import_("torch")
     import torch as torch_

@@ -98,7 +98,7 @@ class torch(Backend):
     @staticmethod
     @einx.trace
     def to_tensor(arg, shape):
-        assert False
+        raise NotImplementedError("to_tensor is not implemented for PyTorch")
 
     @staticmethod
     @einx.trace
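Swapping `assert False` for an explicit raise is more than cosmetic: assertions are stripped when Python runs with `-O`, so an `assert False` guard silently falls through there, while a raised exception always fires and carries a message. A small sketch of the difference (the function names are illustrative, not from the repo):

def stub_explicit(arg, shape):
    raise NotImplementedError("not implemented for this backend")  # survives -O

def stub_fragile(arg, shape):
    assert False  # removed under `python -O`; the call then returns None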
8 changes: 4 additions & 4 deletions einx/backend/register.py
@@ -19,7 +19,7 @@ def register_for_module(module_name, backend_factory):
         register(backend_factory())
     else:
         # Module is not yet imported -> register factory
-        if not module_name in backend_factories:
+        if module_name not in backend_factories:
             backend_factories[module_name] = []
         backend_factories[module_name].append(backend_factory)

@@ -68,7 +68,7 @@ def _update():

 def _get1(tensor):
     backend = tensortype_to_backend.get(type(tensor), None)
-    if not backend is None:
+    if backend is not None:
         return backend
 
     _update()
@@ -103,7 +103,7 @@ def get(arg):
     for tensor in tensors:
         if tensor is not None:
             backend2 = _get1(tensor)
-            if not backend2 is None:
+            if backend2 is not None:
                 if (
                     backend is not None
                     and backend != backend2
@@ -117,6 +117,6 @@ def get(arg):
                 if backend is None or backend2 != numpy:
                     backend = backend2
     if backend is None:
-        raise ValueError(f"Could not determine the backend to use in this operation")
+        raise ValueError("Could not determine the backend to use in this operation")
     else:
         return backend
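
Two cleanups meet in this file: negated identity and membership tests become `is not` / `not in` (E714/E713 in ruff's numbering, as I read it), and an f-string with no placeholders loses its `f` prefix (F541), since it was just a plain string literal. Both in a self-contained sketch:

backend = None

# "not x is None" parses as "not (x is None)" and works, but reads
# backwards; "x is not None" is the idiomatic single operator.
if backend is not None:  # was: if not backend is None:
    print("have a backend")

# No placeholders, so the f prefix did nothing.
msg = "Could not determine the backend to use in this operation"  # was: f"..."
print(msg)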
2 changes: 1 addition & 1 deletion einx/expr/solver.py
@@ -94,7 +94,7 @@ def __str__(self):
         return " + ".join(str(c) for c in self.children)
 
     def sympy(self):
-        return sum([c.sympy() for c in self.children])
+        return sum(c.sympy() for c in self.children)
 
 
 class Product(Expression):
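Dropping the brackets inside `sum(...)` swaps the list comprehension for a generator expression: `sum` then consumes items one at a time instead of materializing an intermediate list first, which saves memory on long inputs and reads identically. For example:

children = [1.0, 2.5, 3.5]
total = sum(c * c for c in children)  # was: sum([c * c for c in children])
print(total)  # 19.5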
2 changes: 1 addition & 1 deletion einx/expr/stage2.py
@@ -934,7 +934,7 @@ def replace(expr):
         while i < len(expr):
             # Check if a subexpression starts at position i
             exprlist_found = None
-            for idx, common_expr in enumerate(common_exprs):
+            for idx, common_expr in enumerate(common_exprs):  # noqa: B007
                 for exprlist in common_expr:
                     for j in range(len(exprlist)):
                         if i + j >= len(expr) or id(exprlist[j]) != id(expr[i + j]):
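This one is suppressed rather than autofixed: flake8-bugbear's B007 flags loop control variables that the loop body never uses, and `idx` is not read inside the body shown. The `noqa` keeps the name, presumably because it is read after the loop ends, which B007 cannot see. A sketch of that pattern:

# In Python the loop variable outlives the loop, so code below the loop
# may legitimately need a name the body itself ignores.
common_exprs = [["a"], ["b"], ["c"]]
for idx, common_expr in enumerate(common_exprs):  # noqa: B007
    if common_expr == ["b"]:
        break
print(idx)  # 1: read after the loop, which is why the suppression is kept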
6 changes: 3 additions & 3 deletions einx/nn/equinox.py
@@ -65,9 +65,9 @@ def __init__(self, name, init, dtype):
         self.dtype = dtype
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
 
         if name is None:
             raise ValueError("Must specify name for tensor factory eqx.Module")
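This attribute-or-kwarg fallback recurs nearly verbatim in the flax, haiku, keras, and torch adapters below: a value fixed when the factory was constructed wins, otherwise it comes from the per-call kwargs. Ruff only flips the negated identity tests; the resolution logic is untouched. A condensed sketch (the helper name `resolve` is mine, not the repo's):

def resolve(attr, kwargs, key):
    # The factory attribute takes precedence; fall back to the call-site kwarg.
    return attr if attr is not None else kwargs.get(key, None)

print(resolve(None, {"name": "w"}, "name"))    # 'w'   (fallback to kwargs)
print(resolve("bias", {"name": "w"}, "name"))  # 'bias' (attribute wins)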
6 changes: 3 additions & 3 deletions einx/nn/flax.py
@@ -73,9 +73,9 @@ def __init__(self, name, init, dtype, col, param_type):
         self.param_type = param_type
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
         col = self.col
 
         if name is None:
8 changes: 4 additions & 4 deletions einx/nn/haiku.py
@@ -60,9 +60,9 @@ def __init__(self, name, init, dtype, param_type, depend_on):
         self.depend_on = depend_on
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
 
         if name is None:
             raise ValueError("Must specify name for tensor factory hk.get_{parameter|state}")
@@ -90,7 +90,7 @@ def __call__(self, shape, kwargs):
         elif self.param_type == "state":
             func = thk.get_state
         else:
-            assert False
+            raise AssertionError(f"Unknown parameter type '{self.param_type}'")
 
         return einx.tracer.apply(
             func,
6 changes: 3 additions & 3 deletions einx/nn/keras.py
@@ -78,9 +78,9 @@ def __init__(self, name, init, dtype, trainable):
         self.trainable = trainable
 
     def __call__(self, shape, kwargs):
-        name = self.name if not self.name is None else kwargs.get("name", None)
-        init = self.init if not self.init is None else kwargs.get("init", None)
-        dtype = self.dtype if not self.dtype is None else kwargs.get("dtype", None)
+        name = self.name if self.name is not None else kwargs.get("name", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
+        dtype = self.dtype if self.dtype is not None else kwargs.get("dtype", None)
 
         if name is None:
             raise ValueError("Must specify name for tensor factory keras.layers.Layer")
2 changes: 1 addition & 1 deletion einx/nn/torch.py
@@ -50,7 +50,7 @@ def __init__(self, init):
         self.init = init
 
     def __call__(self, shape, kwargs):
-        init = self.init if not self.init is None else kwargs.get("init", None)
+        init = self.init if self.init is not None else kwargs.get("init", None)
 
         x = self
 
2 changes: 1 addition & 1 deletion einx/op/util.py
@@ -188,7 +188,7 @@ def _unflatten(exprs_in, tensors_in, expr_out, backend):
 def unflatten(exprs_in, tensors_in, exprs_out, *, backend):
     if len(exprs_in) != len(tensors_in):
         raise ValueError("Got different number of input expressions and tensors")
-    assert not backend is None
+    assert backend is not None
 
     iter_exprs_in = iter(exprs_in)
     iter_tensors_in = iter(tensors_in)