From 8241d18c7224a56a495d306e8c3a2d9f5de529b7 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 30 Nov 2024 18:05:22 +0000 Subject: [PATCH] [nnx] fix nanobind --- .github/workflows/flax_test.yml | 15 +- benchmarks/nnx_graph_overhead.py | 1 + docs_nnx/api_reference/flax.nnx/helpers.rst | 5 +- flax/configurations.py | 11 + flax/nnx/extract.py | 6 +- flax/nnx/graph.py | 164 ++--- flax/nnx/transforms/autodiff.py | 6 +- flax/nnx/transforms/iteration.py | 8 +- flax/nnx/variablelib.py | 2 +- flax/typing.py | 33 + flaxlib_src/CMakeLists.txt | 57 ++ flaxlib_src/Cargo.lock | 295 -------- flaxlib_src/Cargo.toml | 12 - flaxlib_src/meson.build | 14 - flaxlib_src/pyproject.toml | 17 +- flaxlib_src/src/flaxlib.cpp | 659 ++++++++++++++++++ .../src/flaxlib/__init__.py | 20 +- flaxlib_src/src/flaxlib/flaxlib_cpp.pyi | 55 ++ flaxlib_src/src/lib.cc | 14 - flaxlib_src/src/lib.rs | 28 - .../flaxlib.pyi => tests/nnx/flaxlib_test.py | 5 +- tests/nnx/graph_utils_test.py | 48 +- tests/run_all_tests.sh | 1 + uv.lock | 6 +- 24 files changed, 988 insertions(+), 494 deletions(-) create mode 100644 flaxlib_src/CMakeLists.txt delete mode 100644 flaxlib_src/Cargo.lock delete mode 100644 flaxlib_src/Cargo.toml delete mode 100644 flaxlib_src/meson.build create mode 100644 flaxlib_src/src/flaxlib.cpp rename tests/flaxlib_test.py => flaxlib_src/src/flaxlib/__init__.py (53%) create mode 100644 flaxlib_src/src/flaxlib/flaxlib_cpp.pyi delete mode 100644 flaxlib_src/src/lib.cc delete mode 100644 flaxlib_src/src/lib.rs rename flaxlib_src/flaxlib.pyi => tests/nnx/flaxlib_test.py (88%) diff --git a/.github/workflows/flax_test.yml b/.github/workflows/flax_test.yml index 4bed8d8179..8e2b0354b6 100644 --- a/.github/workflows/flax_test.yml +++ b/.github/workflows/flax_test.yml @@ -88,15 +88,23 @@ jobs: python-version: ['3.10', '3.11'] test-type: [doctest, pytest, pytype, mypy] jax-version: [newest] + use-flaxlib: [true, false] exclude: - test-type: pytype python-version: '3.10' - test-type: mypy python-version: '3.11' + - use-flaxlib: true + test-type: doctest + - use-flaxlib: true + test-type: pytype + - use-flaxlib: true + test-type: mypy include: - python-version: '3.10' test-type: pytest jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml + use-flaxlib: false steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -119,12 +127,17 @@ jobs: else uv pip install "jax==${{ matrix.jax-version }}" "jaxlib==${{ matrix.jax-version }}" fi + if [[ "${{ matrix.use-flaxlib }}" == "true" ]]; then + uv pip install "nanobind" "scikit-build-core[pyproject]" + uv pip install -e flaxlib_src + fi - name: Test with ${{ matrix.test-type }} run: | if [[ "${{ matrix.test-type }}" == "doctest" ]]; then uv run tests/run_all_tests.sh --only-doctest elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then - uv run tests/run_all_tests.sh --only-pytest + FLAX_USE_FLAXLIB=${{ matrix.use-flaxlib }} \ + uv run tests/run_all_tests.sh --only-pytest elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then uv run tests/run_all_tests.sh --only-pytype elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then diff --git a/benchmarks/nnx_graph_overhead.py b/benchmarks/nnx_graph_overhead.py index 73cff6d6d6..bd32fc7883 100644 --- a/benchmarks/nnx_graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -19,6 +19,7 @@ import optax from time import time + from flax import nnx from absl import flags diff --git a/docs_nnx/api_reference/flax.nnx/helpers.rst b/docs_nnx/api_reference/flax.nnx/helpers.rst index 
f2b67522d7..7ff94de201 100644 --- a/docs_nnx/api_reference/flax.nnx/helpers.rst +++ b/docs_nnx/api_reference/flax.nnx/helpers.rst @@ -4,10 +4,7 @@ helpers .. automodule:: flax.nnx .. currentmodule:: flax.nnx -.. autoclass:: Dict - :members: -.. autoclass:: List - :members: + .. autoclass:: Sequential :members: .. autoclass:: TrainState diff --git a/flax/configurations.py b/flax/configurations.py index ba19a572fc..5e1a492fcf 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -22,6 +22,7 @@ class Config: + flax_use_flaxlib: bool # See https://google.github.io/pytype/faq.html. _HAS_DYNAMIC_ATTRIBUTES = True @@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /): raise LookupError(f'Unrecognized config option: {name}') self._values[name] = value + def __repr__(self): + values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items()) + return f'Config({values_repr}\n)' + config = Config() @@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool): ' PRNG keys.' ), ) + +flax_use_flaxlib = bool_flag( + name='flax_use_flaxlib', + default=False, + help='Whether to use flaxlib for C++ acceleration.', +) \ No newline at end of file diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 191a0c195a..48d189b7be 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -67,7 +67,7 @@ def extract_graph_nodes( | tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]] ): """Extracts all graph nodes from a pytree.""" - nodes = graph.RefMap[tp.Any, Index]() + nodes = graph.RefMap[tp.Any, Index]({}) node_prefixes = [] leaves = [] @@ -138,7 +138,7 @@ def check_consistent_aliasing( | None = None, ): if node_prefixes is None: - node_prefixes = graph.RefMap() + node_prefixes = graph.RefMap({}) # collect all paths and prefixes for each node for path, value in graph.iter_graph(node): @@ -324,7 +324,7 @@ def to_tree( assert len(leaf_keys) == len(leaf_prefixes) leaves_out = [] - node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]() + node_prefixes = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]({}) with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 644099a4f8..db5000f168 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -20,6 +20,7 @@ import threading import typing as tp +from flax import config import jax import numpy as np import typing_extensions as tpe @@ -33,15 +34,14 @@ from flax.nnx.statelib import State from flax.nnx import variablelib from flax.nnx.variablelib import Variable, VariableState -from flax.typing import Key, PathParts, is_key_like +from flax.typing import HashableMapping, Key, PathParts, is_key_like A = tp.TypeVar('A') B = tp.TypeVar('B') C = tp.TypeVar('C') F = tp.TypeVar('F', bound=tp.Callable) -HA = tp.TypeVar('HA', bound=tp.Hashable) -HB = tp.TypeVar('HB', bound=tp.Hashable) + KeyT = tp.TypeVar('KeyT', bound=Key) Index = int @@ -66,9 +66,7 @@ def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]): """A mapping that uses object id as the hash for the keys.""" - def __init__( - self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / - ): + def __init__(self, mapping: tp.Mapping[A, B], /): self._mapping: dict[int, tuple[A, B]] = {} self.update(mapping) @@ -90,8 +88,15 @@ def __iter__(self) -> tp.Iterator[A]: def __len__(self) -> int: return len(self._mapping) - def __str__(self) -> str: - return repr(self) 
+RefIndexMapping = RefMap[tp.Any, Index] +IndexRefMapping = dict[Index, tp.Any] + +if not tp.TYPE_CHECKING: + if config.flax_use_flaxlib: + import flaxlib + + RefIndexMapping = flaxlib.RefIndexMapping + IndexRefMapping = flaxlib.IndexRefMapping @dataclasses.dataclass(frozen=True, slots=True) @@ -200,32 +205,14 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: return GRAPH_REGISTRY[x] -class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): - def __init__(self, mapping: tp.Mapping[HA, HB], copy: bool = True): - self._mapping = dict(mapping) if copy else mapping - def __contains__(self, key: object) -> bool: - return key in self._mapping +IndexMapping = HashableMapping[int, int] - def __getitem__(self, key: HA) -> HB: - return self._mapping[key] +if not tp.TYPE_CHECKING: + if config.flax_use_flaxlib: + import flaxlib - def __iter__(self) -> tp.Iterator[HA]: - return iter(self._mapping) - - def __len__(self) -> int: - return len(self._mapping) - - def __hash__(self) -> int: - return hash(tuple(sorted(self._mapping.items()))) - - def __eq__(self, other: tp.Any) -> bool: - return ( - isinstance(other, HashableMapping) and self._mapping == other._mapping - ) - - def __repr__(self) -> str: - return repr(self._mapping) + IndexMapping = flaxlib.IndexMapping class GraphDef(tp.Generic[Node]): @@ -317,7 +304,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable): index: int attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...] metadata: tp.Any - index_mapping: HashableMapping[Index, Index] | None + index_mapping: IndexMapping | None @classmethod def create( @@ -326,14 +313,14 @@ def create( index: int, attributes: tuple[SubGraphAttribute | StaticAttribute | LeafAttribute, ...], metadata: tp.Any, - index_mapping: tp.Mapping[Index, Index] | None, + index_mapping: IndexMapping | None, ): return cls( type=type, index=index, attributes=attributes, metadata=metadata, - index_mapping=HashableMapping(index_mapping) + index_mapping=IndexMapping(index_mapping) # type: ignore if index_mapping is not None else None, ) @@ -388,7 +375,7 @@ def _apply( def flatten( - node: Node, /, ref_index: RefMap[tp.Any, Index] | None = None + node: Node, /, ref_index: RefIndexMapping | None = None ) -> tuple[GraphDef[Node], GraphState]: """Flattens a graph node into a (graphdef, state) pair. @@ -399,15 +386,22 @@ def flatten( nodes that share references. 
""" if ref_index is None: - ref_index = RefMap() - flat_state: list[tuple[PathParts, StateLeaf]] = [] - graphdef = _graph_flatten((), ref_index, flat_state, node) + ref_index = RefIndexMapping({}) + + flat_state: list[tuple[PathParts, StateLeaf]] + if config.flax_use_flaxlib: + import flaxlib # type: ignore + + graphdef, flat_state = flaxlib._graph_flatten_top(ref_index, node) + else: + flat_state = [] + graphdef = _graph_flatten([], ref_index, flat_state, node) return graphdef, GraphState.from_flat_path(flat_state) def _graph_flatten( - path: PathParts, - ref_index: RefMap[tp.Any, Index], + path: list[Key], + ref_index: RefIndexMapping, flat_state: list[tuple[PathParts, StateLeaf]], node: Node, ) -> NodeDef[Node] | NodeRef: @@ -430,8 +424,9 @@ def _graph_flatten( values, metadata = node_impl.flatten(node) for key, value in values: + path.append(key) if is_node(value): - nodedef = _graph_flatten((*path, key), ref_index, flat_state, value) + nodedef = _graph_flatten(path, ref_index, flat_state, value) # subgraphs.append((key, nodedef)) attributes.append(SubGraphAttribute(key, nodedef)) elif isinstance(value, Variable): @@ -440,7 +435,7 @@ def _graph_flatten( LeafAttribute(key, NodeRef(type(value), ref_index[value])) ) else: - flat_state.append(((*path, key), value.to_state())) + flat_state.append((tuple(path), value.to_state())) variable_index = ref_index[value] = len(ref_index) variabledef = VariableDef( type(value), variable_index, HashableMapping(value._var_metadata) @@ -448,12 +443,13 @@ def _graph_flatten( attributes.append(LeafAttribute(key, variabledef)) else: if isinstance(value, (jax.Array, np.ndarray)): - path_str = '/'.join(map(str, (*path, key))) + path_str = '/'.join(map(str, path)) raise ValueError( f'Arrays leaves are not supported, at {path_str!r}: {value}' ) # static_fields.append((key, value)) attributes.append(StaticAttribute(key, value)) + path.pop() nodedef = NodeDef.create( type=node_impl.type, @@ -464,14 +460,20 @@ def _graph_flatten( ) return nodedef +if not tp.TYPE_CHECKING: + if config.flax_use_flaxlib: + import flaxlib + + _graph_flatten = flaxlib._graph_flatten + def unflatten( graphdef: GraphDef[Node], state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], /, *, - index_ref: dict[Index, tp.Any] | None = None, - index_ref_cache: dict[Index, tp.Any] | None = None, + index_ref: IndexRefMapping | None = None, + index_ref_cache: IndexRefMapping | None = None, ) -> Node: """Unflattens a graphdef into a node with the given state. @@ -491,7 +493,7 @@ def unflatten( if isinstance(state, State): state = state.raw_mapping # type: ignore if index_ref is None: - index_ref = {} + index_ref = IndexRefMapping({}) assert isinstance(graphdef, (NodeDef, NodeRef)) node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache) return node @@ -499,8 +501,8 @@ def unflatten( def _graph_unflatten( nodedef: NodeDef[Node] | NodeRef[Node], state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]], - index_ref: dict[Index, tp.Any], - index_ref_cache: dict[Index, tp.Any] | None, + index_ref: IndexRefMapping, + index_ref_cache: IndexRefMapping | None, ) -> Node: """Recursive helper for graph_unflatten. @@ -788,7 +790,7 @@ class GraphContext(threading.local): @dataclasses.dataclass class SplitContext: ctxtag: str | None - ref_index: RefMap[tp.Any, Index] + ref_index: RefIndexMapping @tp.overload def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ... 
@@ -815,17 +817,15 @@ def split( states = _split_state(state, filters) if ctx is not None: if ctx.index_ref is not None and isinstance(graphdef, NodeDef): - index_to_index = compose_mapping(ctx.index_ref, self.ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) - ) + index_to_index = create_index_mapping(ctx.index_ref, self.ref_index) + graphdef = dataclasses.replace(graphdef, index_mapping=index_to_index) return graphdef, *states @contextlib.contextmanager def split_context(ctxtag: str | None = None): - index_ref: RefMap[tp.Any, Index] = RefMap() + index_ref = RefIndexMapping({}) flatten_ctx = SplitContext(ctxtag, index_ref) GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx) @@ -843,7 +843,7 @@ def split_context(ctxtag: str | None = None): @dataclasses.dataclass class MergeContext: ctxtag: str | None - index_ref: dict[Index, tp.Any] + index_ref: IndexRefMapping def merge( self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState @@ -858,9 +858,7 @@ def merge( ): # outer merge (4), create index_ref_cache assert ctx.ref_index is not None - index_ref_cache = compose_mapping_reversed( - ctx.ref_index, graphdef.index_mapping - ) + index_ref_cache = create_index_ref(ctx.ref_index, graphdef.index_mapping) else: # inner merge (2) index_ref_cache = None @@ -877,7 +875,7 @@ def merge( @contextlib.contextmanager def merge_context(ctxtag: str | None = None): - index_ref: dict[Index, tp.Any] = {} + index_ref = IndexRefMapping({}) unflatten_ctx = MergeContext(ctxtag, index_ref) GRAPH_CONTEXT.index_ref_stack.append(unflatten_ctx) @@ -898,7 +896,7 @@ class UpdateContext: """A context manager for handling complex state updates.""" tag: str - ref_index: RefMap[tp.Any, Index] | None + ref_index: RefIndexMapping | None index_ref: dict[Index, tp.Any] | None # define hash and eq to make this an opaque object @@ -908,7 +906,7 @@ def __hash__(self): def __eq__(self, other): return isinstance(other, UpdateContext) - def flatten_end(self, ref_index: RefMap[tp.Any, Index]): + def flatten_end(self, ref_index: RefIndexMapping): if self.ref_index is None: # outer split (1), store the references self.ref_index = ref_index @@ -1000,15 +998,13 @@ def split( :class:`GraphDef` and one or more :class:`State`'s equal to the number of filters passed. If no filters are passed, a single :class:`State` is returned. 
""" - ref_index: RefMap[tp.Any, Index] = RefMap() + ref_index = RefIndexMapping({}) graphdef, state = flatten(node, ref_index) states = _split_state(state, filters) if self.index_ref is not None and isinstance(graphdef, NodeDef): - index_to_index = compose_mapping(self.index_ref, ref_index) - graphdef = dataclasses.replace( - graphdef, index_mapping=HashableMapping(index_to_index, copy=False) - ) + index_to_index = create_index_mapping(self.index_ref, ref_index) + graphdef = dataclasses.replace(graphdef, index_mapping=index_to_index) self.flatten_end(ref_index) @@ -1031,15 +1027,13 @@ def merge( if graphdef.index_mapping is not None: # outer merge (4), create index_ref_cache assert self.ref_index is not None - index_ref_cache = compose_mapping_reversed( - self.ref_index, graphdef.index_mapping - ) + index_ref_cache = create_index_ref(self.ref_index, graphdef.index_mapping) else: # inner merge (2) index_ref_cache = None state = State.merge(state, *states) - index_ref: dict[Index, tp.Any] = {} + index_ref = IndexRefMapping({}) node = unflatten( graphdef, state, index_ref=index_ref, index_ref_cache=index_ref_cache ) @@ -1751,16 +1745,30 @@ def _iter_graph( yield path_parts, node -def compose_mapping( - map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], / -) -> dict[A, C]: - return {a: map_bc[b] for a, b in map_ab.items() if b in map_bc} +def create_index_mapping( + index_ref: IndexRefMapping, ref_index: RefIndexMapping, / +) -> IndexMapping: + return IndexMapping( + {a: ref_index[b] for a, b in index_ref.items() if b in ref_index}, + copy=True, + ) + + +def create_index_ref( + ref_index: RefIndexMapping, index_mapping: IndexMapping, / +) -> IndexRefMapping: + return { + index_mapping[index]: ref + for ref, index in ref_index.items() + if index in index_mapping + } + +if not tp.TYPE_CHECKING: + if config.flax_use_flaxlib: + import flaxlib -def compose_mapping_reversed( - map_ab: tp.Mapping[A, B], map_bc: tp.Mapping[B, C], / -) -> dict[C, A]: - return {map_bc[b]: a for a, b in map_ab.items() if b in map_bc} + create_index_ref = flaxlib.create_index_ref @dataclasses.dataclass(frozen=True) diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 5ef0d183b7..d200955be9 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -427,7 +427,7 @@ def _custom_vjp_split_fn( nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False) tangent_tree_node_args: tuple[tp.Any, ...] 
= struct.field(pytree_node=False) -def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]): +def _extract_index_mappings(x, *, index_mappings: deque[graph.IndexMapping]): if isinstance(x, graph.NodeDef): assert x.index_mapping is not None index_mappings.append(x.index_mapping) @@ -465,7 +465,7 @@ def __call__(self, *pure_args): (args_out, out), ctxtag=self.ctxtag ) # remove index_mapping from NodeDef's but store them in global context - index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state( + index_mappings: deque[graph.IndexMapping] = extract.get_broadcast_state( self.ctxtag ) @@ -664,7 +664,7 @@ def __call__( # insert index_mappings def _insert_index_mappings(x): if isinstance(x, graph.NodeDef): - index_mapping: graph.HashableMapping = index_mappings.popleft() + index_mapping: tp.Mapping = index_mappings.popleft() return dataclasses.replace(x, index_mapping=index_mapping) return x diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 994e582862..903ca668d1 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -650,7 +650,7 @@ def check_carry_same_references(key_path, arg, out): def _extract_index_mappings( pure_carry_arg_out, - carry_index_mappings: list[graph.HashableMapping[int, int]], + carry_index_mappings: list[graph.IndexMapping], /, ): def extract_index_mappings(x): @@ -675,7 +675,7 @@ def extract_index_mappings(x): def _insert_index_mappings( pure_carry_arg_out, - carry_index_mappings: deque[graph.HashableMapping[int, int]], + carry_index_mappings: deque[graph.IndexMapping], /, ): def insert_index_mappings(x): @@ -1096,7 +1096,7 @@ def __call__( # next we have to remove all the index_mappings from the NodeDefs # in the carry outputs because they are not present in the inputs - carry_index_mappings: list[graph.HashableMapping[int, int]] = [] + carry_index_mappings: list[graph.IndexMapping] = [] pure_carry_arg_out = _extract_index_mappings( pure_carry_arg_out, carry_index_mappings ) @@ -1357,7 +1357,7 @@ def per_node_def(nd: graph.NodeDef | graph.NodeRef): return dataclasses.replace( ns, _graphdef=dataclasses.replace( - ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping) + ns._graphdef, index_mapping=graph.IndexMapping(global_index_mapping) ), ) diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py index 4752a9b7bd..4ed854b0a8 100644 --- a/flax/nnx/variablelib.py +++ b/flax/nnx/variablelib.py @@ -783,7 +783,7 @@ def __treescope_repr__(self, path, subtree_renderer): ) def replace(self, value: B) -> VariableState[B]: - return VariableState(self.type, value, **self.get_metadata()) + return VariableState(self.type, value, **self._var_metadata) def to_variable(self) -> Variable[A]: # we use object.__new__ to avoid calling __init__ and bypass the diff --git a/flax/typing.py b/flax/typing.py index a630a3571e..e1419955b2 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -23,6 +23,7 @@ TypeVar, Union, ) +from collections.abc import Iterator from collections.abc import Callable, Hashable, Mapping, Sequence import jax @@ -161,3 +162,35 @@ class Missing: MISSING = Missing() +HA = TypeVar('HA', bound=Hashable) +HB = TypeVar('HB', bound=Hashable) + + +class HashableMapping(Mapping[HA, HB], Hashable): + def __init__(self, mapping: Mapping[HA, HB], copy: bool = True): + self._mapping = ( + dict(mapping) if copy or not isinstance(mapping, dict) else mapping + ) + + def __contains__(self, key: object) -> bool: + return key in self._mapping + + def 
__getitem__(self, key: HA) -> HB: + return self._mapping[key] + + def __iter__(self) -> Iterator[HA]: + return iter(self._mapping) + + def __len__(self) -> int: + return len(self._mapping) + + def __hash__(self) -> int: + return hash(tuple(sorted(self._mapping.items()))) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, HashableMapping) and self._mapping == other._mapping + ) + + def __repr__(self) -> str: + return repr(self._mapping) \ No newline at end of file diff --git a/flaxlib_src/CMakeLists.txt b/flaxlib_src/CMakeLists.txt new file mode 100644 index 0000000000..28b2b8cf36 --- /dev/null +++ b/flaxlib_src/CMakeLists.txt @@ -0,0 +1,57 @@ +# Set the minimum CMake version and policies for highest tested version +cmake_minimum_required(VERSION 3.15...3.27) + +# Set up the project and ensure there is a working C++ compiler +project(flaxlib LANGUAGES CXX) + +# Warn if the user invokes CMake directly +if (NOT SKBUILD) + message(WARNING "\ + This CMake file is meant to be executed using 'scikit-build-core'. + Running it directly will almost certainly not produce the desired + result. If you are a user trying to install this package, use the + command below, which will install all necessary build dependencies, + compile the package in an isolated environment, and then install it. + ===================================================================== + $ pip install . + ===================================================================== + If you are a software developer, and this is your own package, then + it is usually much more efficient to install the build dependencies + in your environment once and use the following command that avoids + a costly creation of a new virtual environment at every compilation: + ===================================================================== + $ pip install nanobind scikit-build-core[pyproject] + $ pip install --no-build-isolation -ve . + ===================================================================== + You may optionally add -Ceditable.rebuild=true to auto-rebuild when + the package is imported. Otherwise, you need to rerun the above + after editing C++ files.") +endif() + +# Try to import all Python components potentially needed by nanobind +find_package(Python 3.8 + REQUIRED COMPONENTS Interpreter Development.Module + OPTIONAL_COMPONENTS Development.SABIModule) + +# Import nanobind through CMake's find_package mechanism +find_package(nanobind CONFIG REQUIRED) +find_package(OpenSSL REQUIRED) + +# We are now ready to compile the actual extension module +nanobind_add_module( + # Name of the extension + flaxlib_cpp + + # Target the stable ABI for Python 3.12+, which reduces + # the number of binary wheels that must be built. This + # does nothing on older Python versions + STABLE_ABI + + # Source code goes here + src/flaxlib.cpp +) + +target_link_libraries(flaxlib_cpp PRIVATE OpenSSL::SSL OpenSSL::Crypto) + +# Install directive for scikit-build-core +install(TARGETS flaxlib_cpp LIBRARY DESTINATION flaxlib) \ No newline at end of file diff --git a/flaxlib_src/Cargo.lock b/flaxlib_src/Cargo.lock deleted file mode 100644 index 6a6decf96f..0000000000 --- a/flaxlib_src/Cargo.lock +++ /dev/null @@ -1,295 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. 
-version = 3 - -[[package]] -name = "autocfg" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" - -[[package]] -name = "bitflags" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "flaxlib" -version = "0.0.1-a1" -dependencies = [ - "pyo3", -] - -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - -[[package]] -name = "indoc" -version = "2.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" - -[[package]] -name = "libc" -version = "0.2.158" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" - -[[package]] -name = "lock_api" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" -dependencies = [ - "autocfg", - "scopeguard", -] - -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ea5043e58958ee56f3e15a90aee535795cd7dfd319846288d93c5b57d85cbe" - -[[package]] -name = "parking_lot" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" -dependencies = [ - "lock_api", - "parking_lot_core", -] - -[[package]] -name = "parking_lot_core" -version = "0.9.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" -dependencies = [ - "cfg-if", - "libc", - "redox_syscall", - "smallvec", - "windows-targets", -] - -[[package]] -name = "portable-atomic" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" - -[[package]] -name = "proc-macro2" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "pyo3" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" 
-dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" -dependencies = [ - "heck", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn", -] - -[[package]] -name = "quote" -version = "1.0.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "redox_syscall" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" -dependencies = [ - "bitflags", -] - -[[package]] -name = "scopeguard" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" - -[[package]] -name = "smallvec" -version = "1.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" - -[[package]] -name = "syn" -version = "2.0.77" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "target-lexicon" -version = "0.12.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" - -[[package]] -name = "unicode-ident" -version = "1.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" - -[[package]] -name = "unindent" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" - -[[package]] -name = "windows-targets" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" 
-source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/flaxlib_src/Cargo.toml b/flaxlib_src/Cargo.toml deleted file mode 100644 index 80e9515239..0000000000 --- a/flaxlib_src/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "flaxlib" -version = "0.0.1-a1" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[lib] -name = "flaxlib" -crate-type = ["cdylib"] - -[dependencies] -pyo3 = "0.21.2" diff --git a/flaxlib_src/meson.build b/flaxlib_src/meson.build deleted file mode 100644 index 0d78d9436b..0000000000 --- a/flaxlib_src/meson.build +++ /dev/null @@ -1,14 +0,0 @@ -project( - 'flaxlib', - 'cpp', - version: '0.0.1', - default_options: ['cpp_std=c++17'], -) -py = import('python').find_installation() -nanobind_dep = dependency('nanobind', static: true) -py.extension_module( - 'flaxlib', - sources: ['src/lib.cc'], - dependencies: [nanobind_dep], - install: true, -) \ No newline at end of file diff --git a/flaxlib_src/pyproject.toml b/flaxlib_src/pyproject.toml index 0afc7699a5..fd6c0b61b4 100644 --- a/flaxlib_src/pyproject.toml +++ b/flaxlib_src/pyproject.toml @@ -1,17 +1,28 @@ [build-system] -requires = ['meson-python'] -build-backend = 'mesonpy' +requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"] +build-backend = "scikit_build_core.build" [project] name = "flaxlib" +version = "0.0.1" requires-python = ">=3.10" classifiers = [ "Programming Language :: C++", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dynamic = ["version"] + [project.optional-dependencies] tests = [ "pytest", ] + +[tool.scikit-build] +# Protect the configuration against future changes in scikit-build-core +minimum-version = "0.4" + +# Setuptools-style build caching in a local directory +build-dir = "build/{wheel_tag}" + +# Build stable ABI wheels for CPython 3.12+ +wheel.py-api = "cp312" \ No newline at end of file diff --git a/flaxlib_src/src/flaxlib.cpp b/flaxlib_src/src/flaxlib.cpp new file mode 100644 index 0000000000..8ab917cad7 --- /dev/null +++ b/flaxlib_src/src/flaxlib.cpp @@ -0,0 +1,659 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nb = nanobind; +using namespace nb::literals; + +namespace flaxlib +{ + 
// ----------------------------------- + // helper functions + // ----------------------------------- + intptr_t get_id(nb::object obj) + { + // Get the object ID + return reinterpret_cast(obj.ptr()); + } + + bool nb_isinstance(nanobind::handle inst, nanobind::handle cls) + { + int ret = PyObject_IsInstance(inst.ptr(), cls.ptr()); + if (ret == -1) + { + throw nb::python_error(); + } + return ret; + } + + nb::object vector_to_tuple(const std::vector &vec) + { + + if (vec.empty()) + { + return nb::tuple(); + } + else + { + auto ls = nb::list(); + for (const auto &item : vec) + { + ls.append(item); + } + auto result = nb::tuple(ls); + return result; + } + } + + // ----------------------------------- + // IndexMapping + // ----------------------------------- + class IndexMappingKeysIterator + { + public: + IndexMappingKeysIterator(const std::unordered_map &data) : it(data.begin()), end(data.end()) {} + + int next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return it++->first; + } + + IndexMappingKeysIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map::const_iterator it; + std::unordered_map::const_iterator end; + }; + + struct IndexMapping + { + std::unordered_map mapping; + + IndexMapping(std::unordered_map &mapping, bool copy) + { + if (copy) + { + this->mapping = mapping; + } + else + { + this->mapping = std::move(mapping); + } + } + + // define the python __hash__ method + uint64_t __hash__() + { + EVP_MD_CTX *mdctx; + const EVP_MD *md; + unsigned char md_value[EVP_MAX_MD_SIZE]; + unsigned int md_len; + + // Serialize the map + std::stringstream ss; + for (const auto &pair : mapping) + { + ss << pair.first << ":" << pair.second << ","; + } + std::string serializedData = ss.str(); + + OpenSSL_add_all_digests(); + + md = EVP_get_digestbyname("SHA256"); + if (!md) + { + throw std::runtime_error("Unknown message digest BLAKE3"); + } + + mdctx = EVP_MD_CTX_new(); + EVP_DigestInit_ex(mdctx, md, NULL); + EVP_DigestUpdate(mdctx, serializedData.c_str(), serializedData.size()); + EVP_DigestFinal_ex(mdctx, md_value, &md_len); + EVP_MD_CTX_free(mdctx); + + // Convert (part of) the digest to a 64-bit integer + uint64_t result = 0; + for (size_t i = 0; i < 8 && i < md_len; ++i) + { + result = (result << 8) | static_cast(md_value[i]); + } + + return result; + } + + // define the python __repr__ method + std::string __repr__() + { + std::string repr; + if (mapping.size() == 1) + { + repr = "IndexMapping({"; + for (const auto &pair : mapping) + { + repr += std::to_string(pair.first) + ": " + std::to_string(pair.second); + } + repr += "})"; + } + else + { + repr = "IndexMapping({\n"; + for (const auto &pair : mapping) + { + repr += " " + std::to_string(pair.first) + ": " + std::to_string(pair.second) + ",\n"; + } + if (!mapping.empty()) + { + repr.pop_back(); + repr.pop_back(); + } + repr += "\n})"; + } + return repr; + } + + // define the python __getitem__ method + int __getitem__(int key) const + { + return mapping.at(key); + } + + // define __iter__ method + IndexMappingKeysIterator __iter__() const + { + return IndexMappingKeysIterator(mapping); + } + + // define the python __len__ method + size_t __len__() const + { + return mapping.size(); + } + + // define the python __contains__ method + bool __contains__(int key) const + { + return mapping.find(key) != mapping.end(); + } + + bool __eq__(const nb::object &other) const + { + if (!nb::isinstance(other)) + { + return false; + } + + auto other_mapping = nb::cast(other); + return mapping == 
other_mapping.mapping; + } + + nb::object items() const + { + return nb::make_iterator( + nb::type>>(), "IndexMappingItemsIterator", mapping.begin(), mapping.end()); + } + }; + + // ----------------------------------- + // RefIndexMapping + // ----------------------------------- + + struct RefIndexMappingKeysIterator + { + public: + RefIndexMappingKeysIterator(const std::unordered_map> &data) : it(data.begin()), end(data.end()) {} + + nb::object next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return it++->second.first; + } + + RefIndexMappingKeysIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map>::const_iterator it; + std::unordered_map>::const_iterator end; + }; + + struct RefIndexMappingItemsIterator + { + public: + RefIndexMappingItemsIterator(const std::unordered_map> &data) : it(data.begin()), end(data.end()) {} + + std::pair next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return it++->second; + } + + RefIndexMappingItemsIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map>::const_iterator it; + std::unordered_map>::const_iterator end; + }; + + struct RefIndexMapping + { + std::unordered_map> mapping; + + RefIndexMapping(std::map ref_mapping) + { + for (const auto &pair : ref_mapping) + { + mapping[get_id(pair.first)] = {pair.first, pair.second}; + } + } + + int __getitem__(nb::object key) const + { + return mapping.at(get_id(key)).second; + } + + bool __contains__(nb::object key) const + { + return mapping.find(get_id(key)) != mapping.end(); + } + + void __setitem__(nb::object key, int value) + { + mapping[get_id(key)] = {key, value}; + } + + void __delitem__(nb::object key) + { + mapping.erase(get_id(key)); + } + + RefIndexMappingKeysIterator __iter__() const + { + return RefIndexMappingKeysIterator(mapping); + } + + size_t __len__() const + { + return mapping.size(); + } + + // __repr__ method + std::string __repr__() + { + std::string repr; + if (mapping.size() == 1) + { + repr = "RefIndexMapping({"; + for (const auto &pair : mapping) + { + repr += nb::cast(nb::repr(pair.second.first)) + ": " + std::to_string(pair.second.second); + } + repr += "})"; + } + else + { + repr = "RefIndexMapping({\n"; + for (const auto &pair : mapping) + { + repr += " " + nb::cast(nb::repr(pair.second.first)) + ": " + std::to_string(pair.second.second) + ",\n"; + } + if (!mapping.empty()) + { + repr.pop_back(); + repr.pop_back(); + } + repr += "\n})"; + } + return repr; + } + + RefIndexMappingItemsIterator items() const + { + return RefIndexMappingItemsIterator(mapping); + } + }; + + // ------------------------------------- + // IndexRefMapping + // ------------------------------------- + + struct IndexRefMappingKeysIterator + { + public: + IndexRefMappingKeysIterator(const std::unordered_map &data) : it(data.begin()), end(data.end()) {} + + int next() + { + if (it == end) + { + throw nb::stop_iteration(); + } + + return get_id(it++->second); + } + + IndexRefMappingKeysIterator &__iter__() + { + return *this; + } + + private: + std::unordered_map::const_iterator it; + std::unordered_map::const_iterator end; + }; + + struct IndexRefMapping + { + std::unordered_map mapping; + + IndexRefMapping(std::unordered_map mapping) : mapping(mapping) {} + + nb::object __getitem__(int key) const + { + return mapping.at(key); + } + + bool __contains__(int key) const + { + return mapping.find(key) != mapping.end(); + } + + void __setitem__(int key, nb::object value) + { + mapping[key] = value; + } + + void __delitem__(int 
key) + { + mapping.erase(key); + } + + IndexRefMappingKeysIterator __iter__() const + { + return IndexRefMappingKeysIterator(mapping); + } + + size_t __len__() const + { + return mapping.size(); + } + + std::string __repr__() + { + std::string repr; + if (mapping.size() <= 1) + { + repr = "IndexRefMapping({"; + for (const auto &pair : mapping) + { + repr += std::to_string(pair.first) + ": " + nb::cast(nb::repr(pair.second)); + } + repr += "})"; + } + else + { + repr = "IndexRefMapping({\n"; + for (const auto &pair : mapping) + { + repr += " " + std::to_string(pair.first) + ": " + nb::cast(nb::repr(pair.second)) + ",\n"; + } + if (!mapping.empty()) + { + repr.pop_back(); + repr.pop_back(); + } + repr += "\n})"; + } + return repr; + } + + nb::object items() const + { + return nb::make_iterator(nb::type>>(), "IndexRefMappingItemsIterator", mapping.begin(), mapping.end()); + } + }; + + // ------------------------------------- + // functions + // ------------------------------------- + + IndexRefMapping create_index_ref(RefIndexMapping ref_index, IndexMapping index_mapping) + { + std::unordered_map new_mapping; + for (const auto &pair : ref_index.mapping) + { + auto a = pair.second.first; + auto b = pair.second.second; + + auto b_pos = index_mapping.mapping.find(b); + if (b_pos != index_mapping.mapping.end()) + { + new_mapping[b_pos->second] = a; + } + } + return IndexRefMapping(new_mapping); + } + + nb::object _graph_flatten( + std::vector &path, + RefIndexMapping &ref_index, + std::vector> &flat_state, + nb::object node) + { + // import graph Module from flax.nnx + auto graph = nb::module_::import_("flax.nnx.graph"); + auto jax = nb::module_::import_("jax"); + auto np = nb::module_::import_("numpy"); + + auto jax_Array = nb::getattr(jax, "Array"); + auto np_ndarray = nb::getattr(np, "ndarray"); + auto GraphNodeImpl = nb::getattr(graph, "GraphNodeImpl"); + auto Variable = nb::getattr(graph, "Variable"); + auto SubGraphAttribute = nb::getattr(graph, "SubGraphAttribute"); + auto StaticAttribute = nb::getattr(graph, "StaticAttribute"); + auto LeafAttribute = nb::getattr(graph, "LeafAttribute"); + auto NodeRef = nb::getattr(graph, "NodeRef"); + auto NodeDef = nb::getattr(graph, "NodeDef"); + auto VariableDef = nb::getattr(graph, "VariableDef"); + auto HashableMapping = nb::getattr(graph, "HashableMapping"); + + if (!nb::bool_(nb::getattr(graph, "is_node")(node))) + { + throw std::runtime_error("Unsupported type: " + nb::cast(node.type().attr("__name__")) + ", this is a bug."); + } + + if (ref_index.__contains__(node)) + { + return NodeRef(node.type(), ref_index.__getitem__(node)); + } + + auto node_impl = nb::getattr(graph, "get_node_impl")(node); + + int index; + // only cache graph nodes + if (nb_isinstance(node_impl, GraphNodeImpl)) + { + index = ref_index.__len__(); + ref_index.__setitem__(node, index); + } + else + { + index = -1; + } + + std::vector attributes; + + auto values_metadata = nb::getattr(node_impl, "flatten")(node); + auto values = values_metadata[0]; + auto metadata = values_metadata[1]; + + for (const auto &key_value : values) + { + auto key = key_value[0]; + auto value = key_value[1]; + + path.push_back(key); + + if (nb::bool_(nb::getattr(graph, "is_node")(value))) + { + auto nodedef = _graph_flatten(path, ref_index, flat_state, value); + attributes.push_back(SubGraphAttribute(key, nodedef)); + } + else if (nb_isinstance(value, Variable)) + { + if (ref_index.__contains__(value)) + { + attributes.push_back(LeafAttribute(key, NodeRef(value.type(), 
ref_index.__getitem__(value)))); + } + else + { + auto path_tuple = vector_to_tuple(path); + flat_state.push_back({path_tuple, nb::getattr(value, "to_state")()}); + auto variable_index = ref_index.__len__(); + ref_index.__setitem__(value, variable_index); + auto var_meta = HashableMapping(nb::getattr(value, "_var_metadata")); + auto variabledef = VariableDef(value.type(), variable_index, var_meta); + attributes.push_back(LeafAttribute(key, variabledef)); + } + } + else + { + if (nb_isinstance(value, jax_Array) || nb_isinstance(value, np_ndarray)) + { + std::string path_str; + for (const auto &part : path) + { + path_str += nb::cast(nb::repr(part)) + "/"; + } + throw std::runtime_error("Arrays leaves are not supported, at " + path_str + ": " + nb::cast(nb::repr(value))); + } + attributes.push_back(StaticAttribute(key, value)); + } + path.pop_back(); + } + + auto attributes_tuple = vector_to_tuple(attributes); + auto nodedef = nb::getattr(NodeDef, "create")( + nb::getattr(node_impl, "type"), index, attributes_tuple, metadata, nb::none()); + + return nodedef; + } + + std::pair _graph_flatten_top( + RefIndexMapping &ref_index, + nb::object node) + { + // print "here" + std::vector path = {}; + std::vector> flat_state = {}; + auto nodedef = _graph_flatten(path, ref_index, flat_state, node); + + auto flat_state_out = nb::list(); + for (const auto &pair : flat_state) + { + flat_state_out.append(nb::make_tuple(pair.first, pair.second)); + } + return {nodedef, flat_state_out}; + } + + NB_MODULE(flaxlib_cpp, m) + { + //------------------------------------------------------------------------- + // IndexMapping + //------------------------------------------------------------------------- + nb::class_(m, "IndexMapping") + .def(nb::init &, bool>(), nb::arg("mapping"), nb::arg("copy") = true) + .def("__hash__", &IndexMapping::__hash__) + .def("__repr__", &IndexMapping::__repr__) + .def("__getitem__", &IndexMapping::__getitem__) + .def("__iter__", &IndexMapping::__iter__) + .def("__len__", &IndexMapping::__len__) + .def("__contains__", &IndexMapping::__contains__, nb::arg("key").none()) + .def("__eq__", &IndexMapping::__eq__) + .def("items", &IndexMapping::items); + + nb::class_(m, "IndexMappingIterator") + .def("__iter__", &IndexMappingKeysIterator::__iter__) + .def("__next__", &IndexMappingKeysIterator::next); + + //------------------------------------------------------------------------- + // RefIndexMapping + //------------------------------------------------------------------------- + nb::class_(m, "RefIndexMapping") + .def(nb::init>()) + .def("__getitem__", &RefIndexMapping::__getitem__) + .def("__contains__", &RefIndexMapping::__contains__, nb::arg("key").none()) + .def("__setitem__", &RefIndexMapping::__setitem__) + .def("__delitem__", &RefIndexMapping::__delitem__) + .def("__iter__", &RefIndexMapping::__iter__) + .def("__len__", &RefIndexMapping::__len__) + .def("__repr__", &RefIndexMapping::__repr__) + .def("items", &RefIndexMapping::items); + + nb::class_(m, "RefIndexMappingKeysIterator") + .def("__iter__", &RefIndexMappingKeysIterator::__iter__) + .def("__next__", &RefIndexMappingKeysIterator::next); + + nb::class_(m, "RefIndexMappingItemsIterator") + .def("__iter__", &RefIndexMappingItemsIterator::__iter__) + .def("__next__", &RefIndexMappingItemsIterator::next); + + //------------------------------------------------------------------------- + // IndexRefMapping + //------------------------------------------------------------------------- + nb::class_(m, "IndexRefMapping") + .def(nb::init 
&>()) + .def("__getitem__", &IndexRefMapping::__getitem__) + .def("__contains__", &IndexRefMapping::__contains__, nb::arg("key").none()) + .def("__setitem__", &IndexRefMapping::__setitem__) + .def("__delitem__", &IndexRefMapping::__delitem__) + .def("__iter__", &IndexRefMapping::__iter__) + .def("__len__", &IndexRefMapping::__len__) + .def("__repr__", &IndexRefMapping::__repr__) + .def("items", &IndexRefMapping::items); + + nb::class_(m, "IndexRefMappingKeysIterator") + .def("__iter__", &IndexRefMappingKeysIterator::__iter__) + .def("__next__", &IndexRefMappingKeysIterator::next); + + //------------------------------------------------------------------------- + // functions + //------------------------------------------------------------------------- + m.def("create_index_ref", &create_index_ref); + m.def("_graph_flatten_top", &_graph_flatten_top); + m.def("_graph_flatten", &_graph_flatten); + } + +} // namespace flaxlib \ No newline at end of file diff --git a/tests/flaxlib_test.py b/flaxlib_src/src/flaxlib/__init__.py similarity index 53% rename from tests/flaxlib_test.py rename to flaxlib_src/src/flaxlib/__init__.py index c23f70baa7..bc194a4055 100644 --- a/tests/flaxlib_test.py +++ b/flaxlib_src/src/flaxlib/__init__.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .flaxlib_cpp import IndexMapping as IndexMapping +from .flaxlib_cpp import RefIndexMapping as RefIndexMapping +from .flaxlib_cpp import IndexRefMapping as IndexRefMapping +from .flaxlib_cpp import create_index_ref as create_index_ref +from .flaxlib_cpp import _graph_flatten as _graph_flatten +from .flaxlib_cpp import _graph_flatten_top as _graph_flatten_top -# TODO: Re-enable this test after setting up CI build for flaxlib CC. +# ----------------------------- +# Register pytrees types +# ----------------------------- +import jax -# from absl.testing import absltest -# import flaxlib +jax.tree_util.register_static(IndexMapping) - -# class TestFlaxlib(absltest.TestCase): - -# def test_flaxlib(self): -# self.assertEqual(flaxlib.sum_as_string(1, 2), '3') +del jax \ No newline at end of file diff --git a/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi new file mode 100644 index 0000000000..1ab90b0b7c --- /dev/null +++ b/flaxlib_src/src/flaxlib/flaxlib_cpp.pyi @@ -0,0 +1,55 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Iterator + +def sum_as_string(a: int, b: int) -> str: ... + +def create_index_ref( + ref_index: RefIndexMapping, index_mapping: IndexMapping +) -> IndexRefMapping: ... + +class IndexMapping: + def __init__(self, mapping: dict[int, int], /) -> None: ... + def __hash__(self) -> int: ... + def __getitem__(self, key: int) -> int: ... + def __len__(self) -> int: ... + def __contains__(self, key: int) -> bool: ... + def __iter__(self) -> Iterator[int]: ... + def items(self) -> Iterator[tuple[int, int]]: ... 
+ +class RefIndexMapping: + def __init__(self, ref_mapping: dict[Any, int], /) -> None: ... + def __getitem__(self, key: Any) -> int: ... + def __contains__(self, key: Any) -> bool: ... + def __setitem__(self, key: Any, value: int) -> None: ... + def __delitem__(self, key: Any) -> None: ... + def __iter__(self) -> Iterator[Any]: ... + def __len__(self) -> int: ... + def items(self) -> Iterator[tuple[Any, int]]: ... + +class IndexRefMapping: + def __init__(self, mapping: dict[int, Any], /) -> None: ... + def __getitem__(self, key: int) -> Any: ... + def __contains__(self, key: int) -> bool: ... + def __setitem__(self, key: int, value: Any) -> None: ... + def __delitem__(self, key: int) -> None: ... + def __iter__(self) -> Iterator[int]: ... + def __len__(self) -> int: ... + def items(self) -> Iterator[tuple[int, Any]]: ... + +def _graph_flatten_top(ref_index: RefIndexMapping, node: Any) -> Any: ... +def _graph_flatten( + path: list, ref_index: RefIndexMapping, flat_state: list, node: Any +) -> Any: ... diff --git a/flaxlib_src/src/lib.cc b/flaxlib_src/src/lib.cc deleted file mode 100644 index c714588118..0000000000 --- a/flaxlib_src/src/lib.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include - -#include "nanobind/nanobind.h" -#include "nanobind/stl/string.h" - -namespace flaxlib { -std::string sum_as_string(int a, int b) { - return std::to_string(a + b); -} - -NB_MODULE(flaxlib, m) { - m.def("sum_as_string", &sum_as_string); -} -} // namespace flaxlib \ No newline at end of file diff --git a/flaxlib_src/src/lib.rs b/flaxlib_src/src/lib.rs deleted file mode 100644 index cadab2ef22..0000000000 --- a/flaxlib_src/src/lib.rs +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2024 The Flax Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use pyo3::prelude::*; - -/// Formats the sum of two numbers as string. -#[pyfunction] -fn sum_as_string(a: usize, b: usize) -> PyResult { - Ok((a + b).to_string()) -} - -/// A Python module implemented in Rust. -#[pymodule] -fn flaxlib(_py: Python, m: &Bound) -> PyResult<()> { - m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; - Ok(()) -} diff --git a/flaxlib_src/flaxlib.pyi b/tests/nnx/flaxlib_test.py similarity index 88% rename from flaxlib_src/flaxlib.pyi rename to tests/nnx/flaxlib_test.py index 505fd3d0f0..6b1d96347c 100644 --- a/flaxlib_src/flaxlib.pyi +++ b/tests/nnx/flaxlib_test.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -def sum_as_string(a: int, b: int) -> str: ... 
+from absl.testing import absltest + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index a7bbf178cb..f3b47c1162 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -59,11 +59,15 @@ def __call__(self, x): class TestGraphUtils(absltest.TestCase): + def test_flatten_basic(self): + m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) + graphdef, state = nnx.split(m) + def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - refmap = nnx.graph.RefMap() + refmap = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(g, ref_index=refmap) state[0]['b'].raw_value = 2 @@ -326,7 +330,7 @@ def f(m: Foo): a = m.a b = m.b - ref_out_idx_out = nnx.graph.RefMap() + ref_out_idx_out = nnx.graph.RefIndexMapping({}) graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @@ -335,19 +339,19 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() + ref_in_idx_in = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + idx_out_idx_in = nnx.graph.create_index_mapping( + idx_out_ref_in, ref_in_idx_in + ) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out static_out: nnx.graph.Static state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] + idx_out_idx_in: nnx.graph.IndexMapping graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in - ) + idx_in_ref_out = nnx.graph.create_index_ref(ref_out_idx_out, idx_out_idx_in) m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b @@ -366,7 +370,7 @@ def f(m: Foo): a = m.a b = m.b - ref_out_idx_out = nnx.graph.RefMap[Any, int]() + ref_out_idx_out = nnx.graph.RefIndexMapping({}) graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, ref_index=ref_out_idx_out) @@ -375,19 +379,19 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() + ref_in_idx_in = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + idx_out_idx_in = nnx.graph.create_index_mapping( + idx_out_ref_in, ref_in_idx_in + ) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out static_out: nnx.graph.Static state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] + idx_out_idx_in: nnx.graph.IndexMapping graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in - ) + idx_in_ref_out = nnx.graph.create_index_ref(ref_out_idx_out, idx_out_idx_in) m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.a is b @@ -403,7 +407,7 @@ def f(m: Foo): m = Foo() - ref_out_idx_out = nnx.graph.RefMap() + ref_out_idx_out = nnx.graph.RefIndexMapping({}) graphdef: nnx.graph.GraphDef[Foo] graphdef, state = nnx.graph.flatten(m, 
ref_index=ref_out_idx_out) @@ -412,19 +416,19 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): idx_out_ref_in: dict[int, Any] = {} m = nnx.graph.unflatten(graphdef, state, index_ref=idx_out_ref_in) f(m) - ref_in_idx_in = nnx.graph.RefMap[Any, int]() + ref_in_idx_in = nnx.graph.RefIndexMapping({}) graphdef, state = nnx.graph.flatten(m, ref_index=ref_in_idx_in) - idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + idx_out_idx_in = nnx.graph.create_index_mapping( + idx_out_ref_in, ref_in_idx_in + ) static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out static_out: nnx.graph.Static state, static_out = f_pure(graphdef, state) - idx_out_idx_in: dict[int, int] + idx_out_idx_in: nnx.graph.IndexMapping graphdef, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph.compose_mapping_reversed( - ref_out_idx_out, idx_out_idx_in - ) + idx_in_ref_out = nnx.graph.create_index_ref(ref_out_idx_out, idx_out_idx_in) m2 = nnx.graph.unflatten(graphdef, state, index_ref_cache=idx_in_ref_out) assert m2 is m assert m2.ref is m2 diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 920d71017b..2c210b18d0 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -63,6 +63,7 @@ echo "GH_VENV: $GH_VENV" echo "WHICH PYTHON: $(which python)" echo "jax: $(python -c 'import jax; print(jax.__version__)')" echo "flax: $(python -c 'import flax; print(flax.__version__)')" +echo "flax config: $(python -c 'from flax import config; print(config)')" echo "==========================" echo "" diff --git a/uv.lock b/uv.lock index a30155113e..9a9a8dea1e 100644 --- a/uv.lock +++ b/uv.lock @@ -2263,7 +2263,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.10.1" +version = "0.10.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2281,9 +2281,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/07/24/f13f75810a00873f779625b4fff9419d09f95a56bedb01453ac2b4990ce8/orbax_checkpoint-0.10.1.tar.gz", hash = "sha256:aaf44f5a10ced74badc7fcaf8a2396e9047a20a61487ad5e8514e539d7992cd8", size = 230081 } +sdist = { url = "https://files.pythonhosted.org/packages/d1/06/c42e2f1563dbaaf5ed1464d7b634324fb9a2da04021073c45777e61af78d/orbax_checkpoint-0.10.2.tar.gz", hash = "sha256:e575ebe1f94e5cb6353ab8c9df81de0ca7cddc118645c3bfc17b8344f19d42f1", size = 248170 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/67/a175072cd7e5a215b12f39f4d9d891881a6220d75e30ae6480d05647bdf4/orbax_checkpoint-0.10.1-py3-none-any.whl", hash = "sha256:b4d7ae295d89a329c39109f945ff690d47c1db04eac644fa5316b2f42b5fa9e5", size = 328311 }, + { url = "https://files.pythonhosted.org/packages/61/19/ed366f8894923f3c8db0370e4bdd57ef843d68011dafa00d8175f4a66e1a/orbax_checkpoint-0.10.2-py3-none-any.whl", hash = "sha256:dcfc425674bd8d4934986143bd22a37cd634d034652c5d30d83c539ef8587941", size = 354306 }, ] [[package]]
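
Usage note (illustrative sketch, not part of the patch): the round trip below mirrors what tests/nnx/graph_utils_test.py exercises with the new RefIndexMapping-based flatten path. It assumes flaxlib has been built and installed from flaxlib_src (as the CI job does with `uv pip install -e flaxlib_src`) and that setting the FLAX_USE_FLAXLIB environment variable before importing flax toggles the new flax_use_flaxlib flag, as in the workflow above.

    # Assumption: the flax_use_flaxlib flag is read from this env var,
    # mirroring FLAX_USE_FLAXLIB=true in the CI pytest job. Set it before
    # importing flax so the C++ fast path is picked up at import time.
    import os
    os.environ['FLAX_USE_FLAXLIB'] = 'true'

    from flax import nnx

    # Flatten a module with an explicit ref->index mapping, then rebuild it.
    m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    refmap = nnx.graph.RefIndexMapping({})
    graphdef, state = nnx.graph.flatten(m, ref_index=refmap)
    m2 = nnx.graph.unflatten(graphdef, state)
    assert isinstance(m2, nnx.Linear)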