diff --git a/src/lightning_utilities/core/apply_func.py b/src/lightning_utilities/core/apply_func.py
new file mode 100644
index 00000000..0fcf8a8f
--- /dev/null
+++ b/src/lightning_utilities/core/apply_func.py
@@ -0,0 +1,233 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import dataclasses
+from collections import defaultdict, OrderedDict
+from copy import deepcopy
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union
+
+
+def is_namedtuple(obj: object) -> bool:
+    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
+    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
+
+
+def is_dataclass_instance(obj: object) -> bool:
+    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
+    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
+
+
+def apply_to_collection(
+    data: Any,
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
+    function: Callable,
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
+    include_none: bool = True,
+    **kwargs: Any,
+) -> Any:
+    """Recursively applies a function to all elements of a certain dtype.
+
+    Args:
+        data: the collection to apply the function to
+        dtype: the given function will be applied to all elements of this dtype
+        function: the function to apply
+        *args: positional arguments (will be forwarded to calls of ``function``)
+        wrong_dtype: the given function won't be applied if this type is specified and the given collection
+            is of the ``wrong_dtype`` even if it is of type ``dtype``
+        include_none: Whether to include an element if the output of ``function`` is ``None``.
+        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+    Returns:
+        The resulting collection
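+
+    Example (an illustrative sketch, mirroring the accompanying tests)::
+
+        >>> apply_to_collection({"a": 1, "b": [2, 3]}, int, lambda x: x * 2)
+        {'a': 2, 'b': [4, 6]}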
+    """
+    # Breaking condition
+    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
+        return function(data, *args, **kwargs)
+
+    elem_type = type(data)
+
+    # Recursively apply to collection items
+    if isinstance(data, Mapping):
+        out = []
+        for k, v in data.items():
+            v = apply_to_collection(
+                v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
+            )
+            if include_none or v is not None:
+                out.append((k, v))
+        if isinstance(data, defaultdict):
+            return elem_type(data.default_factory, OrderedDict(out))
+        return elem_type(OrderedDict(out))
+
+    is_namedtuple_ = is_namedtuple(data)
+    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
+    if is_namedtuple_ or is_sequence:
+        out = []
+        for d in data:
+            v = apply_to_collection(
+                d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
+            )
+            if include_none or v is not None:
+                out.append(v)
+        return elem_type(*out) if is_namedtuple_ else elem_type(out)
+
+    if is_dataclass_instance(data):
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        fields = {}
+        memo = {}
+        for field in dataclasses.fields(data):
+            field_value = getattr(data, field.name)
+            fields[field.name] = (field_value, field.init)
+            memo[id(field_value)] = field_value
+        result = deepcopy(data, memo=memo)
+        # apply function to each field
+        for field_name, (field_value, field_init) in fields.items():
+            v = None
+            if field_init:
+                v = apply_to_collection(
+                    field_value,
+                    dtype,
+                    function,
+                    *args,
+                    wrong_dtype=wrong_dtype,
+                    include_none=include_none,
+                    **kwargs,
+                )
+            if not field_init or (not include_none and v is None):  # retain old value
+                v = getattr(data, field_name)
+            try:
+                setattr(result, field_name, v)
+            except dataclasses.FrozenInstanceError as e:
+                raise ValueError(
+                    "A frozen dataclass was passed to `apply_to_collection` but this is not allowed."
+                ) from e
+        return result
+
+    # data is neither of dtype, nor a collection
+    return data
+
+
+def apply_to_collections(
+    data1: Optional[Any],
+    data2: Optional[Any],
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
+    function: Callable,
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
+    **kwargs: Any,
+) -> Any:
+    """Zips two collections and applies a function to their items of a certain dtype.
+
+    Args:
+        data1: The first collection
+        data2: The second collection
+        dtype: the given function will be applied to all elements of this dtype
+        function: the function to apply
+        *args: positional arguments (will be forwarded to calls of ``function``)
+        wrong_dtype: the given function won't be applied if this type is specified and the given collection
+            is of the ``wrong_dtype`` even if it is of type ``dtype``
+        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
+
+    Returns:
+        The resulting collection
+
+    Raises:
+        AssertionError:
+            If sequence collections have different data sizes.
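+
+    Example (an illustrative sketch, mirroring the accompanying tests)::
+
+        >>> apply_to_collections({"a": 1}, {"a": 4}, int, lambda x, y: x + y)
+        {'a': 5}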
+    """
+    if data1 is None:
+        if data2 is None:
+            return
+        # in case they were passed reversed
+        data1, data2 = data2, None
+
+    elem_type = type(data1)
+
+    if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
+        return function(data1, data2, *args, **kwargs)
+
+    if isinstance(data1, Mapping) and data2 is not None:
+        # use union because we want to fail if a key does not exist in both
+        zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
+        return elem_type(
+            {
+                k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
+                for k, v in zipped.items()
+            }
+        )
+
+    is_namedtuple_ = is_namedtuple(data1)
+    is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
+    if (is_namedtuple_ or is_sequence) and data2 is not None:
+        assert len(data1) == len(data2), "Sequence collections have different sizes."
+        out = [
+            apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
+            for v1, v2 in zip(data1, data2)
+        ]
+        return elem_type(*out) if is_namedtuple_ else elem_type(out)
+
+    if is_dataclass_instance(data1) and data2 is not None:
+        if not is_dataclass_instance(data2):
+            raise TypeError(
+                "Expected inputs to be dataclasses of the same type or to have identical fields"
+                f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
+            )
+        if not (
+            len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
+            and all(map(lambda f1, f2: isinstance(f1, type(f2)), dataclasses.fields(data1), dataclasses.fields(data2)))
+        ):
+            raise TypeError("Dataclasses fields do not match.")
+        # make a deepcopy of the data,
+        # but do not deepcopy mapped fields since the computation would
+        # be wasted on values that likely get immediately overwritten
+        data = [data1, data2]
+        fields: List[dict] = [{}, {}]
+        memo: dict = {}
+        for i in range(len(data)):
+            for field in dataclasses.fields(data[i]):
+                field_value = getattr(data[i], field.name)
+                fields[i][field.name] = (field_value, field.init)
+                if i == 0:
+                    memo[id(field_value)] = field_value
+
+        result = deepcopy(data1, memo=memo)
+
+        # apply function to each field
+        for ((field_name, (field_value1, field_init1)), (_, (field_value2, field_init2))) in zip(
+            fields[0].items(), fields[1].items()
+        ):
+            v = None
+            if field_init1 and field_init2:
+                v = apply_to_collections(
+                    field_value1,
+                    field_value2,
+                    dtype,
+                    function,
+                    *args,
+                    wrong_dtype=wrong_dtype,
+                    **kwargs,
+                )
+            if not field_init1 or not field_init2 or v is None:  # retain old value
+                return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
+            try:
+                setattr(result, field_name, v)
+            except dataclasses.FrozenInstanceError as e:
+                raise ValueError(
+                    "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
+                ) from e
+        return result
+
+    return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
diff --git a/tests/unittests/__init__.py b/tests/unittests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unittests/core/__init__.py b/tests/unittests/core/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unittests/core/mocks.py b/tests/unittests/core/mocks.py
new file mode 100644
index 00000000..adb5abdf
--- /dev/null
+++ b/tests/unittests/core/mocks.py
@@ -0,0 +1,53 @@
+from typing import Iterable
+
+from lightning_utilities.core.imports import package_available
+
+if package_available("torch"):
+    import torch
+else:
+    # minimal torch implementation to avoid installing torch in testing CI
+    class TensorMock:
+        def __init__(self, data):
+            self.data = data
+
+        def __add__(self, other):
+            if isinstance(self.data, Iterable):
+                if isinstance(other, (int, float)):
+                    return TensorMock([a + other for a in self.data])
+                if isinstance(other, Iterable):
+                    return TensorMock([a + b for a, b in zip(self, other)])
+            return self.data + other
+
+        def __mul__(self, other):
+            if isinstance(self.data, Iterable):
+                if isinstance(other, (int, float)):
+                    return TensorMock([a * other for a in self.data])
+                if isinstance(other, Iterable):
+                    return TensorMock([a * b for a, b in zip(self, other)])
+            return self.data * other
+
+        def __iter__(self):
+            return iter(self.data)
+
+        def __repr__(self):
+            return repr(self.data)
+
+        def __eq__(self, other):
+            return self.data == other
+
+    class TorchMock:
+        Tensor = TensorMock
+
+        @staticmethod
+        def tensor(data):
+            return TensorMock(data)
+
+        @staticmethod
+        def equal(a, b):
+            return a == b
+
+        @staticmethod
+        def arange(*args):
+            return TensorMock(list(range(*args)))
+
+    torch = TorchMock()
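+
+# Illustrative sanity sketch (not part of the original module): arithmetic on the mock
+# is elementwise, like the real ``torch.Tensor`` operators the tests rely on, e.g.
+# TensorMock([1, 2]) + 1 == [2, 3] and TensorMock([1, 2]) * TensorMock([3, 4]) == [3, 8].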
diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py
new file mode 100644
index 00000000..47a474ba
--- /dev/null
+++ b/tests/unittests/core/test_apply_func.py
@@ -0,0 +1,326 @@
+import dataclasses
+import numbers
+from collections import defaultdict, namedtuple, OrderedDict
+from dataclasses import InitVar
+from typing import Any, ClassVar, List, Optional
+
+import pytest
+from unittests.core.mocks import torch
+
+from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
+
+
+@dataclasses.dataclass
+class Feature:
+    input_ids: torch.Tensor
+    segment_ids: torch.Tensor
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, Feature):
+            return NotImplemented
+        return torch.equal(self.input_ids, o.input_ids) and torch.equal(self.segment_ids, o.segment_ids)
+
+
+@dataclasses.dataclass
+class ModelExample:
+    example_ids: List[str]
+    feature: Feature
+    label: torch.Tensor
+    some_constant: int = dataclasses.field(init=False)
+
+    def __post_init__(self):
+        self.some_constant = 7
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, ModelExample):
+            return NotImplemented
+
+        return (
+            self.example_ids == o.example_ids
+            and self.feature == o.feature
+            and torch.equal(self.label, o.label)
+            and self.some_constant == o.some_constant
+        )
+
+
+@dataclasses.dataclass
+class WithClassVar:
+    class_var: ClassVar[int] = 0
+    dummy: Any
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, WithClassVar):
+            return NotImplemented
+        elif isinstance(self.dummy, torch.Tensor):
+            return torch.equal(self.dummy, o.dummy)
+
+        return self.dummy == o.dummy
+
+
+@dataclasses.dataclass
+class WithInitVar:
+    dummy: Any
+    override: InitVar[Optional[Any]] = None
+
+    def __post_init__(self, override: Optional[Any]):
+        if override is not None:
+            self.dummy = override
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, WithInitVar):
+            return NotImplemented
+        elif isinstance(self.dummy, torch.Tensor):
+            return torch.equal(self.dummy, o.dummy)
+
+        return self.dummy == o.dummy
+
+
+@dataclasses.dataclass
+class WithClassAndInitVar:
+    class_var: ClassVar[torch.Tensor] = torch.tensor(0)
+    dummy: Any
+    override: InitVar[Optional[Any]] = torch.tensor(1)
+
+    def __post_init__(self, override: Optional[Any]):
+        if override is not None:
+            self.dummy = override
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, WithClassAndInitVar):
+            return NotImplemented
+        elif isinstance(self.dummy, torch.Tensor):
+            return torch.equal(self.dummy, o.dummy)
+
+        return self.dummy == o.dummy
+
+
+def test_recursive_application_to_collection():
+    ntc = namedtuple("Foo", ["bar"])
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=torch.tensor([4.0, 5.0, 6.0])),
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
+    to_reduce = {
+        "a": torch.tensor([1.0]),  # Tensor
+        "b": [torch.tensor([2.0])],  # list
+        "c": (torch.tensor([100.0]),),  # tuple
+        "d": ntc(bar=5.0),  # named tuple
+        "f": "this_is_a_dummy_str",  # string
+        "g": 12.0,  # number
+        "h": Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=torch.tensor([4.0, 5.0, 6.0])),  # dataclass
+        "i": model_example,  # nested dataclass
+        "j": WithClassVar(torch.arange(3)),  # dataclass with class variable
+        "k": WithInitVar("this_gets_overridden", torch.tensor([2.0])),  # dataclass with init-only variable
+        "l": WithClassAndInitVar(model_example, None),  # nested dataclass with class and init-only variables
+    }
+
+    model_example_result = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=torch.tensor([8.0, 10.0, 12.0])),
+        label=torch.tensor([14.0, 16.0, 18.0]),
+    )
+
+    expected_result = {
+        "a": torch.tensor([2.0]),
+        "b": [torch.tensor([4.0])],
+        "c": (torch.tensor([200.0]),),
+        "d": ntc(bar=10),
+        "f": "this_is_a_dummy_str",
+        "g": 24.0,
+        "h": Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=torch.tensor([8.0, 10.0, 12.0])),
+        "i": model_example_result,
+        "j": WithClassVar(torch.arange(0, 6, 2)),
+        "k": WithInitVar(torch.tensor([4.0])),
+        "l": WithClassAndInitVar(model_example_result, None),
+    }
+
+    reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number), lambda x: x * 2)
+
+    assert isinstance(reduced, dict), "Type Consistency of dict not preserved"
+    assert all(x in reduced for x in to_reduce), "Not all entries of the dict were preserved"
+    assert all(
+        isinstance(reduced[k], type(expected_result[k])) for k in to_reduce
+    ), "At least one type was not correctly preserved"
+
+    assert isinstance(reduced["a"], torch.Tensor), "Reduction Result of a Tensor should be a Tensor"
+    assert torch.equal(expected_result["a"], reduced["a"]), "Reduction of a tensor does not yield the expected value"
+
+    assert isinstance(reduced["b"], list), "Reduction Result of a list should be a list"
+    assert all(
+        torch.equal(x, y) for x, y in zip(reduced["b"], expected_result["b"])
+    ), "At least one value of list reduction did not come out as expected"
+
+    assert isinstance(reduced["c"], tuple), "Reduction Result of a tuple should be a tuple"
+    assert all(
+        torch.equal(x, y) for x, y in zip(reduced["c"], expected_result["c"])
+    ), "At least one value of tuple reduction did not come out as expected"
+
+    assert isinstance(reduced["d"], ntc), "Type Consistency for named tuple not given"
+    assert isinstance(
+        reduced["d"].bar, numbers.Number
+    ), "Failure in type promotion while reducing fields of named tuples"
+    assert reduced["d"].bar == expected_result["d"].bar
+
+    assert isinstance(reduced["f"], str), "A string should not be reduced"
+    assert reduced["f"] == expected_result["f"], "String not preserved during reduction"
+
+    assert isinstance(reduced["g"], numbers.Number), "Reduction of a number should result in a number"
+    assert reduced["g"] == expected_result["g"], "Reduction of a number did not yield the desired result"
+
+    def _assert_dataclass_reduction(actual, expected, dataclass_type: str = ""):
+        assert dataclasses.is_dataclass(actual) and not isinstance(
+            actual, type
+        ), f"Reduction of a {dataclass_type} dataclass should result in a dataclass"
+        for field in dataclasses.fields(actual):
+            if dataclasses.is_dataclass(field.type):
+                _assert_dataclass_reduction(getattr(actual, field.name), getattr(expected, field.name), "nested")
+        assert actual == expected, f"Reduction of a {dataclass_type} dataclass did not yield the desired result"
+
+    _assert_dataclass_reduction(reduced["h"], expected_result["h"])
+
+    _assert_dataclass_reduction(reduced["i"], expected_result["i"])
+
+    dataclass_type = "ClassVar-containing"
+    _assert_dataclass_reduction(reduced["j"], expected_result["j"], dataclass_type)
+    assert WithClassVar.class_var == 0, f"Reduction of a {dataclass_type} dataclass should not change the class var"
+
+    _assert_dataclass_reduction(reduced["k"], expected_result["k"], "InitVar-containing")
+
+    dataclass_type = "Class-and-InitVar-containing"
+    _assert_dataclass_reduction(reduced["l"], expected_result["l"], dataclass_type)
+    assert torch.equal(
+        WithClassAndInitVar.class_var, torch.tensor(0)
+    ), f"Reduction of a {dataclass_type} dataclass should not change the class var"
+
+    # mapping support
+    reduced = apply_to_collection({"a": 1, "b": 2}, int, lambda x: str(x))
+    assert reduced == {"a": "1", "b": "2"}
+    reduced = apply_to_collection(OrderedDict([("b", 2), ("a", 1)]), int, lambda x: str(x))
+    assert reduced == OrderedDict([("b", "2"), ("a", "1")])
+
+    # custom mappings
+    class _CustomCollection(dict):
+        def __init__(self, initial_dict):
+            super().__init__(initial_dict)
+
+    to_reduce = _CustomCollection({"a": 1, "b": 2, "c": 3})
+    reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
+    assert reduced == _CustomCollection({"a": "1", "b": "2", "c": "3"})
+
+    # defaultdict
+    to_reduce = defaultdict(int, {"a": 1, "b": 2, "c": 3})
+    reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
+    assert reduced == defaultdict(int, {"a": "1", "b": "2", "c": "3"})
+
+
+def test_apply_to_collection_include_none():
+    to_reduce = [1, 2, 3.4, 5.6, 7, (8, 9.1, {10: 10})]
+
+    def fn(x):
+        if isinstance(x, float):
+            return x
+
+    reduced = apply_to_collection(to_reduce, (int, float), fn)
+    assert reduced == [None, None, 3.4, 5.6, None, (None, 9.1, {10: None})]
+
+    reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
+    assert reduced == [3.4, 5.6, (9.1, {})]
+
+
+def test_apply_to_collections():
+    to_reduce_1 = {"a": {"b": [1, 2]}, "c": 5}
+    to_reduce_2 = {"a": {"b": [3, 4]}, "c": 6}
+
+    def fn(a, b):
+        return a + b
+
+    # basic test
+    reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn)
+    assert reduced == {"a": {"b": [4, 6]}, "c": 11}
+
+    with pytest.raises(KeyError):
+        # strict mode - if a key does not exist in both we fail
+        apply_to_collections({**to_reduce_2, "d": "foo"}, to_reduce_1, float, fn)
+
+    # multiple dtypes
+    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn)
+    assert reduced == {"a": {"b": [1, 2, 3, 4]}, "c": 11}
+
+    # wrong dtype
+    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int)
+    assert reduced == {"a": {"b": [1, 2, 3, 4]}, "c": 5}
+
+    # list takes precedence because it is the type of data1
+    reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn)
+    assert reduced == [1, 2, 3, 4]
+
+    # different sizes
+    with pytest.raises(AssertionError, match="Sequence collections have different sizes"):
+        apply_to_collections([[1, 2], [3]], [4], int, fn)
+
+    def fn(a, b):
+        return a.keys() | b.keys()
+
+    # base case
+    reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn)
+    assert reduced == {"a", "c"}
+
+    # type conversion
+    to_reduce = [(1, 2), (3, 4)]
+    reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
+    assert reduced == [(2, 4), (6, 8)]
+
+    # named tuple
+    foo = namedtuple("Foo", ["bar"])
+    to_reduce = [foo(1), foo(2), foo(3)]
+    reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
+    assert reduced == [foo(2), foo(4), foo(6)]
+
+    # passing none
+    reduced1 = apply_to_collections([1, 2, 3], None, int, lambda x: x * x)
+    reduced2 = apply_to_collections(None, [1, 2, 3], int, lambda x: x * x)
+    assert reduced1 == reduced2 == [1, 4, 9]
+    reduced = apply_to_collections(None, None, int, lambda x: x * x)
+    assert reduced is None
+
+
+def test_apply_to_collections_dataclass():
+    to_reduce_1 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=torch.tensor([4.0, 5.0, 6.0]))
+    to_reduce_2 = Feature(input_ids=torch.tensor([1.0, 2.0, 3.0]), segment_ids=torch.tensor([4.0, 5.0, 6.0]))
+
+    def fn(a, b):
+        return a + b
+
+    reduced = apply_to_collections(to_reduce_1, to_reduce_2, torch.Tensor, fn)
+    assert reduced == Feature(input_ids=torch.tensor([2.0, 4.0, 6.0]), segment_ids=torch.tensor([8.0, 10.0, 12.0]))
+
+    model_example = ModelExample(
+        example_ids=["i-1", "i-2", "i-3"],
+        feature=to_reduce_1,
+        label=torch.tensor([7.0, 8.0, 9.0]),
+    )
+
+    # different types
+    with pytest.raises(TypeError, match="Expected inputs to be dataclasses of the same type"):
+        apply_to_collections(to_reduce_1, [1, 2], torch.Tensor, fn)
+
+    # unmatched fields
+    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
+        apply_to_collections(to_reduce_1, model_example, torch.Tensor, fn)
+
+    classvar = WithClassVar(torch.arange(3))  # dataclass with a different number of fields
+    with pytest.raises(TypeError, match="Dataclasses fields do not match"):
+        apply_to_collections(to_reduce_1, classvar, torch.Tensor, fn)
+
+
+def test_apply_to_collection_frozen_dataclass():
+    @dataclasses.dataclass(frozen=True)
+    class Foo:
+        input: int
+
+    foo = Foo(0)
+    with pytest.raises(ValueError, match="frozen dataclass was passed"):
+        apply_to_collection(foo, int, lambda x: x + 1)
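+
+
+def test_apply_to_collection_quickstart():
+    # Illustrative sketch, not part of the original suite: recursion through a mixed
+    # dict/list structure applies the function only to elements of the target dtype.
+    doubled = apply_to_collection({"a": 1, "b": [2, 3], "c": "keep"}, int, lambda x: x * 2)
+    assert doubled == {"a": 2, "b": [4, 6], "c": "keep"}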