Overhaul of docs and readme
fferflo committed Jun 4, 2024
1 parent ee99cd2 commit 73dd98c
Showing 32 changed files with 715 additions and 781 deletions.
29 changes: 12 additions & 17 deletions README.md
@@ -1,26 +1,23 @@
# *einx* - Tensor Operations in Einstein-Inspired Notation
# *einx* - Universal Tensor Operations in Einstein-Inspired Notation

[![pytest](https://github.com/fferflo/einx/actions/workflows/run_pytest.yml/badge.svg)](https://github.com/fferflo/einx/actions/workflows/run_pytest.yml)
[![Documentation](https://img.shields.io/badge/documentation-link-blue.svg)](https://einx.readthedocs.io)
[![PyPI version](https://badge.fury.io/py/einx.svg)](https://badge.fury.io/py/einx)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/)

einx is a Python library that allows formulating many tensor operations as concise expressions using Einstein notation. It is inspired by [einops](https://github.com/arogozhnikov/einops), but follows a novel and unique design:
einx is a Python library that provides a universal interface to formulate tensor operations in frameworks such as Numpy, PyTorch, Jax and Tensorflow. The design is based on the following principles:

- Fully composable and powerful Einstein expressions with `[]`-notation.
- Support for many tensor operations (`einx.{sum|max|where|add|dot|flip|get_at|...}`) with Numpy-like naming.
- Easy integration and mixing with existing code. Supports tensor frameworks Numpy, PyTorch, Tensorflow, Jax and others.
- Just-in-time compilation of all operations into regular Python functions using Python's [`exec()`](https://docs.python.org/3/library/functions.html#exec).
1. **Provide a set of elementary tensor operations** following Numpy-like naming: `einx.{sum|max|where|add|dot|flip|get_at|...}`
2. **Use einx notation to express vectorization of the elementary operations.** einx notation is inspired by [einops](https://github.com/arogozhnikov/einops), but introduces several novel concepts such as `[]`-bracket notation and full composability that allow using it as a universal language for tensor operations.

*Optional:*

- Generalized neural network layers in Einstein notation. Supports PyTorch, Flax, Haiku, Equinox and Keras.
einx can be integrated and mixed with existing code seamlessly. All operations are [just-in-time compiled](https://einx.readthedocs.io/en/latest/more/jit.html) into regular Python functions using Python's [exec()](https://docs.python.org/3/library/functions.html#exec) and invoke operations from the respective framework.
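
For example, a minimal sketch of how the two principles combine (array shapes and inputs are chosen purely for illustration):

```python
import numpy as np
import einx

x = np.random.rand(4, 8)

# Elementary reduction, vectorized via []-notation:
y = einx.sum("a [b]", x)   # sums over axis b, like np.sum(x, axis=1); y.shape == (4,)

# The same notation composes with axis compositions and ellipses:
x2 = np.random.rand(2, 8, 8, 3)
z = einx.mean("b (s [ds])... c", x2, ds=2)   # 2x2 mean-pooling; z.shape == (2, 4, 4, 3)
```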

**Getting started:**

* [Tutorial](https://einx.readthedocs.io/en/latest/gettingstarted/einsteinnotation.html)
* [Example: GPT-2/ Mamba with einx](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html)
* [How does einx compare with einops?](https://einx.readthedocs.io/en/latest/faq/einops.html)
* [Tutorial](https://einx.readthedocs.io/en/latest/gettingstarted/tutorial_overview.html)
* [Example: GPT-2 with einx](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html)
* [How is einx different from einops?](https://einx.readthedocs.io/en/latest/faq/einops.html)
* [How is einx notation universal?](https://einx.readthedocs.io/en/latest/faq/universal.html)
* [API reference](https://einx.readthedocs.io/en/latest/api.html)

## Installation
@@ -50,8 +47,6 @@ einx.get_at("b [h w] c, b i [2] -> b i c", x, y) # Gather values at coordinates
einx.rearrange("b (q + k) -> b q, b k", x, q=2) # Split
einx.rearrange("b c, 1 -> b (c + 1)", x, [42]) # Append number to each channel

einx.dot("... [c1->c2]", x, y) # Matmul = linear map from c1 to c2 channels

# Apply custom operations:
einx.vmap("b [s...] c -> b c", x, op=np.mean) # Global mean-pooling
einx.vmap("a [b], [b] c -> a c", x, y, op=np.dot) # Matmul
@@ -84,7 +79,7 @@ einx.dot("b [s...->s2] c", x, w) # - Spatial mixing as in MLP-

See [Common neural network ops](https://einx.readthedocs.io/en/latest/gettingstarted/commonnnops.html) for more examples.

#### Deep learning modules
#### Optional: Deep learning modules

```python
import einx.nn.{torch|flax|haiku|equinox|keras} as einn
@@ -105,7 +100,7 @@ spatial_dropout = einn.Dropout("[b] ... [c]", drop_rate=0.2)
droppath = einn.Dropout("[b] ...", drop_rate=0.2)
```

See `examples/train_{torch|flax|haiku|equinox|keras}.py` for example trainings on CIFAR10, [GPT-2](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html) and [Mamba](https://github.com/fferflo/weightbridge/blob/master/examples/mamba2flax.py) for working example implementations of language models using einx, and [Tutorial: Neural networks](https://einx.readthedocs.io/en/latest/gettingstarted/neuralnetworks.html) for more details.
See `examples/train_{torch|flax|haiku|equinox|keras}.py` for example trainings on CIFAR10, [GPT-2](https://einx.readthedocs.io/en/latest/gettingstarted/gpt2.html) and [Mamba](https://github.com/fferflo/weightbridge/blob/master/examples/mamba2flax.py) for working implementations of language models using einx, and [Tutorial: Neural networks](https://einx.readthedocs.io/en/latest/gettingstarted/tutorial_neuralnetworks.html) for more details.

#### Just-in-time compilation

@@ -122,4 +117,4 @@ def op0(i0):
return x1
```

See [Just-in-time compilation](https://einx.readthedocs.io/en/latest/gettingstarted/jit.html) for more details.
See [Just-in-time compilation](https://einx.readthedocs.io/en/latest/more/jit.html) for more details.
16 changes: 5 additions & 11 deletions docs/source/api.rst
@@ -2,27 +2,16 @@
einx API
########

Abstractions
============

Main
----

.. autofunction:: einx.rearrange
.. autofunction:: einx.vmap_with_axis
.. autofunction:: einx.vmap
.. autofunction:: einx.dot

Partial specializations
-----------------------

.. autofunction:: einx.reduce
.. autofunction:: einx.elementwise
.. autofunction:: einx.index

Numpy-like functions
====================

Reduction operations
--------------------

@@ -76,6 +65,11 @@ Miscellaneous operations
.. autofunction:: einx.log_softmax
.. autofunction:: einx.arange

General dot-product
-------------------

.. autofunction:: einx.dot

Deep Learning Modules
=====================

2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -7,7 +7,7 @@
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "einx"
copyright = "2023, Florian Fervers"
copyright = "2024, Florian Fervers"
author = '<a href="https://fferflo.github.io/">Florian Fervers</a>'

# -- General configuration ---------------------------------------------------
5 changes: 2 additions & 3 deletions docs/source/faq/backend.rst
@@ -1,8 +1,7 @@
How does einx support different tensor frameworks?
##################################################

einx provides interfaces for tensor frameworks in the ``einx.backend.*`` namespace. For each framework, a backend object is implemented that
provides a numpy-like interface for all necessary tensor operations using the framework's own functions. Every einx function accepts a ``backend`` argument
einx provides interfaces for tensor frameworks in the ``einx.backend.*`` namespace. einx functions accept a ``backend`` argument
that defines which backend to use for the computation. For ``backend=None`` (the default case), the backend is implicitly determined from the input tensors.

.. code:: python
@@ -22,7 +21,7 @@ Numpy cannot be mixed in the same operation.
einx.dot("a [c1->c2]", x, jnp.asarray(y)) # Uses jax
einx.dot("a [c1->c2]", torch.from_numpy(x), jnp.asarray(y)) # Raises exception
Unknown tensor objects and Python sequences are converted using ``np.asarray`` and used as numpy backend tensors.
Unknown tensor objects and Python sequences are converted to tensors using calls from the respective backend if possible (e.g. ``np.asarray``, ``torch.asarray``).
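
For example, a minimal sketch of such a conversion (the outer sum is used purely for illustration; the numpy backend is selected here because of the first input):

.. code::

    import numpy as np
    import einx

    # The Python list is converted with np.asarray, since the numpy backend
    # is implicitly selected from the first input tensor:
    result = einx.add("a, b -> a b", np.array([1.0, 2.0, 3.0]), [4, 5, 6])  # shape (3, 3)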

.. code:: python
50 changes: 38 additions & 12 deletions docs/source/faq/einops.rst
@@ -1,23 +1,49 @@
How does einx compare with einops?
How is einx different from einops?
##################################

einx uses Einstein notation that is inspired by and compatible with the notation used in `einops <https://github.com/arogozhnikov/einops>`_,
but follows a novel design:
einx uses Einstein-inspired notation that is based on and compatible with the notation used in `einops <https://github.com/arogozhnikov/einops>`_,
but introduces several novel concepts that allow using it as a universal language for tensor operations:

* Full composability of Einstein expressions: Axis lists, compositions, ellipses and concatenations can be nested arbitrarily (e.g. ``(a b)...`` or
* Introduction of ``[]``-notation to express vectorization of elementary operations (see :ref:`Bracket notation <bracketnotation>`).
* Ellipses repeat the preceding expression rather than an anonymous axis. This allows expressing multi-dimensional operations more concisely
(e.g. ``(a b)...`` or ``b (s [ds])... c``)
* Full composability of expressions: Axis lists, compositions, ellipses, brackets and concatenations can be nested arbitrarily (e.g. ``(a b)...`` or
``b (1 + (s...)) c``).
* Introduction of ``[]``-notation that allows expressing vectorization in an intuitive and concise way, similar to the ``axis`` argument in Numpy functions (see :ref:`Bracket notation <bracketnotation>`).
* Introduction of concatenations as first-class expressions in Einstein notation.
* Introduction of concatenations as first-class expressions.

When combined, these features allow for a concise and expressive formulation of a large variety of tensor operations.
The library provides the following additional features based on the einx notation:

The einx library provides the following additional features:
* Support for many more tensor operations, for example:

.. code::
einx.flip("... (g [c])", x, c=2) # Flip pairs of values
einx.add("a, b -> a b", x, y) # Outer sum
einx.get_at("b [h w] c, b i [2] -> b i c", x, indices) # Gather values
einx.softmax("b q [k] h", attn) # Part of attention operation
* Simpler notation for existing tensor operations:

.. code::
einx.sum("a [b]", x)
# same op as
einops.reduce(x, "a b -> a", reduction="sum")
einx.mean("b (s [ds])... c", x, ds=2)
# einops does not support named ellipses. Alternative for 2D case:
einops.reduce(x, "b (h h2) (w w2) c -> b h w c", reduction="mean", h2=2, w2=2)
* Full support for rearranging expressions in all operations (see :doc:`How does einx handle input and output tensors? </faq/flatten>`).
* ``einx.vmap`` and ``einx.vmap_with_axis`` allow applying arbitrary operations using Einstein notation.
* Specializations provide ease-of-use for main abstractions using Numpy naming convention, e.g. ``einx.sum`` and ``einx.multiply``.
* Several generalized deep learning modules in the ``einx.nn.*`` namespace (see :doc:`Tutorial: Neural networks </gettingstarted/neuralnetworks>`).
* Support for inspecting the backend calls made by einx in index-based notation (see :doc:`Just-in-time compilation </gettingstarted/jit>`).

.. code::
einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=16)
# Axis composition is not supported in e.g. einops.einsum.
* ``einx.vmap`` and ``einx.vmap_with_axis`` allow applying arbitrary operations using einx notation.
* Several generalized deep learning modules in the ``einx.nn.*`` namespace (see :doc:`Tutorial: Neural networks </gettingstarted/tutorial_neuralnetworks>`).
* Support for inspecting the backend calls made by einx in index-based notation (see :doc:`Just-in-time compilation </more/jit>`).

A non-exhaustive comparison of operations expressed in einx-notation and einops-notation:

5 changes: 1 addition & 4 deletions docs/source/faq/flatten.rst
@@ -1,7 +1,7 @@
How does einx handle input and output tensors?
##############################################

einx functions accept an operation string that specifies Einstein expressions for the input and output tensors. The expressions potentially
einx functions accept an operation string that specifies einx expressions for the input and output tensors. The expressions potentially
contain nested compositions and concatenations that prevent the backend functions from directly accessing the required axes. To resolve this, einx
first flattens the input tensors in each operation such that they contain only a flat list of axes. After the backend operation is applied, the
resulting tensors are unflattened to match the requested output expressions.
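
As a rough sketch of this flatten-apply-unflatten pipeline in plain numpy, for ``einx.mean("b (s [r]) c -> b s c", x, r=4)`` (the code einx actually generates may differ):

.. code::

    import numpy as np

    x = np.random.rand(2, 16, 3)        # b=2, (s r)=16 with r=4, c=3

    # 1. Flatten: expand the composition (s r) into separate axes
    x_flat = x.reshape(2, 4, 4, 3)      # b s r c

    # 2. Apply the backend reduction along the bracketed axis r
    y = np.mean(x_flat, axis=2)         # b s c

    # 3. Unflatten: the output "b s c" contains no compositions or
    #    concatenations, so no further reshaping is needed in this case
    assert y.shape == (2, 4, 3)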
@@ -22,9 +22,6 @@ Concatenations are flattened by splitting the input tensor into multiple tensors
# same as
np.split(x, [10], axis=0)
Using a concatenated tensor as input performs the same operation as passing the split tensors as separate inputs to the operation. einx handles
expressions with multiple nested compositions and concatenations gracefully.

After the operation is applied to the flattened tensors, the results are reshaped and concatenated and missing axes are inserted and broadcasted
to match the requested output expressions.

18 changes: 9 additions & 9 deletions docs/source/faq/solver.rst
@@ -1,16 +1,16 @@
How does einx parse Einstein expressions?
#########################################
How does einx parse expressions?
################################

Overview
--------

einx functions accept an operation string that specifies the shapes of input and output tensors and the requested operation in Einstein notation. For example:
einx functions accept an operation string that specifies the shapes of input and output tensors and the requested operation in einx notation. For example:

.. code::
einx.mean("b (s [r])... c -> b s... c", x, r=4) # Mean-pooling with stride 4
To identify the backend operations that are required to execute this statement, einx first parses the operation string and determines an *Einstein expression tree*
To identify the backend operations that are required to execute this statement, einx first parses the operation string and determines an *expression tree*
for each input and output tensor. The tree represents a full description of the tensor's shape and axes marked with brackets. The nodes represent different types of
subexpressions such as axis lists, compositions, ellipses and concatenations. The leaves of the tree are the named and unnamed axes of the tensor. The expression trees
are used to determine the required rearranging steps and axes along which backend operations are applied.
@@ -20,12 +20,12 @@ einx uses a multi-step process to convert expression strings into expression tre
* **Stage 0**: Split the operation string into separate expression strings for each tensor.
* **Stage 1**: Parse the expression string for each tensor and return a (stage-1) tree of nodes representing the nested subexpressions.
* **Stage 2**: Expand all ellipses by repeating the respective subexpression, resulting in a stage-2 tree.
* **Stage 3**: Determine a value for each axis (i.e. the axis length) using the provided constraints, resulting in a stage-3 tree, i.e. the final Einstein expression tree.
* **Stage 3**: Determine a value for each axis (i.e. the axis length) using the provided constraints, resulting in a stage-3 tree, i.e. the final expression tree.
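
As a rough illustration of these stages for the mean-pooling example above (the internal axis naming for expanded ellipses is shown schematically and may differ):

.. code::

    # einx.mean("b (s [r])... c -> b s... c", x, r=4)   with x.shape == (2, 4, 8, 3)

    # Stage 0: split the operation string into one expression per tensor
    #   input:  "b (s [r])... c"                 output: "b s... c"

    # Stage 1: parse each expression string into a tree of nested subexpressions

    # Stage 2: expand the ellipsis once per remaining input dimension
    #   input:  "b (s.0 [r.0]) (s.1 [r.1]) c"    output: "b s.0 s.1 c"

    # Stage 3: solve for the axis values from x.shape and the constraint r=4
    #   b=2, s.0*r.0=4, s.1*r.1=8, c=3, r.0=r.1=4   =>   s.0=1, s.1=2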

For a given operation string and signature of input arguments, the required backend operations are traced into a graph representation and just-in-time compiled using Python's
`exec() <https://docs.python.org/3/library/functions.html#exec>`_. Every subsequent call with the same
signature will reuse the cached function and therefore incur no additional overhead other than for cache lookup (see
:doc:`Just-in-time compilation </gettingstarted/jit>`).
:doc:`Just-in-time compilation </more/jit>`).

Stage 0: Splitting the operation string
---------------------------------------
@@ -53,7 +53,7 @@ Another example of shorthand notation in :func:`einx.dot`:
# same as
einx.dot("a [b->c]", x, y)
See :doc:`Tutorial: Tensor manipulation </gettingstarted/tensormanipulation>` and the documentation of the respective functions for allowed shorthand notation.
See :doc:`Tutorial: Operations </gettingstarted/tutorial_ops>` and the documentation of the respective functions for allowed shorthand notation.

Stage 1: Parsing the expression string
--------------------------------------
Expand All @@ -67,7 +67,7 @@ subexpressions:

Stage-1 tree for ``b (s [r])... c``.

This includes several semantic checks, e.g. to ensure that axis names do not appear more than once per expression.
This includes semantic checks, e.g. to ensure that axis names do not appear more than once per expression.

Stage 2: Expanding ellipses
---------------------------
@@ -103,7 +103,7 @@ Stage 3: Determining axis values
--------------------------------

In the last step, the values of all axes (i.e. their lengths) are determined using the constraints provided by the input tensors and additional parameters. For example, the above
expression with an input tensor of shape ``(2, 4, 8, 3)`` and additional constraint ``r=4`` results in the following final Einstein expression tree:
expression with an input tensor of shape ``(2, 4, 8, 3)`` and additional constraint ``r=4`` results in the following final expression tree:

.. figure:: /images/stage3-tree.png
:height: 240