From 5df2022634ec6421b36db546af9c8acff7b12dc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 25 Jan 2023 13:28:57 +0100 Subject: [PATCH] feat(datatype): implement `Mapping` abstract base class for `StructType` --- ibis/expr/datatypes/core.py | 9 ++++++- ibis/tests/expr/test_datatypes.py | 39 +++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index 9c89c7fc85fa..8173a459a74d 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -2,6 +2,7 @@ import numbers from abc import abstractmethod +from collections.abc import Iterator, Mapping from typing import Any, Iterable, NamedTuple import numpy as np @@ -639,7 +640,7 @@ def to_integer_type(self): @public -class Struct(DataType): +class Struct(DataType, Mapping): """Structured values.""" fields = frozendict_of(instance_of(str), datatype) @@ -677,6 +678,12 @@ def types(self) -> tuple[DataType, ...]: """Return the types of the struct's fields.""" return tuple(self.fields.values()) + def __len__(self) -> int: + return len(self.fields) + + def __iter__(self) -> Iterator[str]: + return iter(self.fields) + def __getitem__(self, key: str) -> DataType: return self.fields[key] diff --git a/ibis/tests/expr/test_datatypes.py b/ibis/tests/expr/test_datatypes.py index 97879cbd5b55..1309761ddb75 100644 --- a/ibis/tests/expr/test_datatypes.py +++ b/ibis/tests/expr/test_datatypes.py @@ -141,6 +141,45 @@ def test_struct_with_string_types(): ) +def test_struct_mapping_api(): + s = dt.Struct( + { + 'a': 'map', + 'b': 'array>>', + 'c': 'array', + 'd': 'int8', + } + ) + + assert s['a'] == dt.Map(dt.double, dt.string) + assert s['b'] == dt.Array(dt.Map(dt.string, dt.Array(dt.int32))) + assert s['c'] == dt.Array(dt.string) + assert s['d'] == dt.int8 + + assert 'a' in s + assert 'e' not in s + assert len(s) == 4 + assert tuple(s) == s.names + assert tuple(s.keys()) == s.names + assert tuple(s.values()) == s.types + assert tuple(s.items()) == tuple(zip(s.names, s.types)) + + s1 = s.copy() + s2 = dt.Struct( + { + 'a': 'map', + 'b': 'array>>', + 'c': 'array', + } + ) + assert s == s1 + assert s != s2 + + # doesn't support item assignment + with pytest.raises(TypeError): + s['e'] = dt.int8 + + @pytest.mark.parametrize( 'case', [