From 2832745d074512bdc6e77f02afbbecc66535962d Mon Sep 17 00:00:00 2001 From: aandres Date: Mon, 14 Oct 2024 10:56:18 +0100 Subject: [PATCH] Add support for duration Test with less messages Better check on view --- .github/workflows/ci.yaml | 6 +- .gitignore | 1 + CHANGELOG.md | 8 ++ docs/contributing.md | 2 +- docs/types.md | 1 + docs/usage.md | 3 +- protarrow/arrow_to_proto.py | 32 ++++++ protarrow/cast_to_proto.py | 3 + protarrow/common.py | 1 + protarrow/proto_to_arrow.py | 12 ++ protos/bench.proto | 211 ++++++++++++++++++------------------ scripts/generate_proto.py | 1 + scripts/template.proto.in | 1 + tests/random_generator.py | 49 ++++++++- tests/test_conversion.py | 80 +++++++++----- 15 files changed, 269 insertions(+), 142 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ad533f5..4f8ce38 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -15,10 +15,10 @@ jobs: fail-fast: false steps: - name: Checkout sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: "pip" @@ -33,7 +33,7 @@ jobs: run: tox - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 if: "matrix.python-version == '3.11'" with: fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 001274a..1952bea 100644 --- a/.gitignore +++ b/.gitignore @@ -131,3 +131,4 @@ dmypy.json # Custom /protarrow_protos +.idea diff --git a/CHANGELOG.md b/CHANGELOG.md index 77aa012..15b0e1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). +## [v0.7.0](https://github.com/tradewelltech/protarrow/releases/tag/v0.7.0) - 2024-10-14 + +[Compare with v0.6.0](https://github.com/tradewelltech/protarrow/compare/v0.6.0...v0.7.0) + +### Added + +- Add support for duration ([c9ab1e2](https://github.com/tradewelltech/protarrow/commit/c9ab1e203712a503dec6bbbd58d340ca37542de5) by aandres). + ## [v0.6.0](https://github.com/tradewelltech/protarrow/releases/tag/v0.6.0) - 2024-09-12 [Compare with v0.5.2](https://github.com/tradewelltech/protarrow/compare/v0.5.2...v0.6.0) diff --git a/docs/contributing.md b/docs/contributing.md index 485f211..fe3d086 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -47,7 +47,7 @@ git-changelog -io CHANGELOG.md For new release, first prepare the change log, push and merge it. ```shell -git-changelog -bio CHANGELOG.md +git-changelog --bump=auto -io CHANGELOG.md ``` Then tag and push: diff --git a/docs/types.md b/docs/types.md index b220745..b83d1f1 100644 --- a/docs/types.md +++ b/docs/types.md @@ -41,6 +41,7 @@ | google.protobuf.UInt64Value | uint64 | | | google.type.Date | date32() | | | google.type.TimeOfDay | **time64**/time32 | Unit and type are configurable | +| google.type.Duration | duration("ns") | Unit is configurable | ## Nullability diff --git a/docs/usage.md b/docs/usage.md index 60518e3..1d22485 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -57,13 +57,14 @@ my_proto_1 = message_extractor.read_table_row(table, 1) ## Customize arrow type -The arrow type for `Enum`, `Timestamp` and `TimeOfDay` can be configured: +The arrow type for `Enum`, `Timestamp` and `TimeOfDay` and `Duration` can be configured: ```python config = protarrow.ProtarrowConfig( enum_type=pa.int32(), timestamp_type=pa.timestamp("ms", "America/New_York"), time_of_day_type=pa.time32("ms"), + duration_type=pa.duration("s"), ) record_batch = protarrow.messages_to_record_batch(my_protos, MyProto, config) ``` diff --git a/protarrow/arrow_to_proto.py b/protarrow/arrow_to_proto.py index 963345d..ea27eda 100644 --- a/protarrow/arrow_to_proto.py +++ b/protarrow/arrow_to_proto.py @@ -5,6 +5,7 @@ import pyarrow as pa from google.protobuf.descriptor import Descriptor, EnumDescriptor, FieldDescriptor +from google.protobuf.duration_pb2 import Duration from google.protobuf.internal.containers import MessageMap from google.protobuf.message import Message from google.protobuf.timestamp_pb2 import Timestamp @@ -75,6 +76,29 @@ def _time_64_ns_scalar_to_proto(scalar: pa.Time64Scalar) -> TimeOfDay: ) +def _duration_ns_scalar_to_proto(scalar: pa.DurationScalar) -> Duration: + total_nanos = scalar.value + return Duration( + nanos=total_nanos % 1_000_000_000, seconds=(total_nanos // 1_000_000_000) + ) + + +def _duration_us_scalar_to_proto(scalar: pa.DurationScalar) -> Duration: + total_us = scalar.value + return Duration( + nanos=(total_us % 1_000_000) * 1_000, seconds=(total_us // 1_000_000) + ) + + +def _duration_ms_scalar_to_proto(scalar: pa.DurationScalar) -> Duration: + total_us = scalar.value + return Duration(nanos=(total_us % 1_000) * 1_000_000, seconds=(total_us // 1_000)) + + +def _duration_s_scalar_to_proto(scalar: pa.DurationScalar) -> Duration: + return Duration(seconds=scalar.value) + + def _time_64_us_scalar_to_proto(scalar: pa.Time64Scalar) -> TimeOfDay: total_us = scalar.value return TimeOfDay( @@ -112,6 +136,13 @@ def _time_32_s_scalar_to_proto(scalar: pa.Time32Scalar) -> TimeOfDay: pa.time32("s"): _time_32_s_scalar_to_proto, } +DURATION_CONVERTERS = { + "ns": _duration_ns_scalar_to_proto, + "us": _duration_us_scalar_to_proto, + "ms": _duration_ms_scalar_to_proto, + "s": _duration_s_scalar_to_proto, +} + TIMESTAMP_CONVERTERS = { "ns": _timestamp_ns_scalar_to_proto, "us": _timestamp_us_scalar_to_proto, @@ -123,6 +154,7 @@ def _time_32_s_scalar_to_proto(scalar: pa.Time32Scalar) -> TimeOfDay: Timestamp.DESCRIPTOR: lambda data_type: TIMESTAMP_CONVERTERS[data_type.unit], Date.DESCRIPTOR: lambda _: _date_scalar_to_proto, TimeOfDay.DESCRIPTOR: TIME_OF_DAY_CONVERTERS.__getitem__, + Duration.DESCRIPTOR: lambda data_type: DURATION_CONVERTERS[data_type.unit], } NULLABLE_TYPES = ( diff --git a/protarrow/cast_to_proto.py b/protarrow/cast_to_proto.py index c0f6868..cf92f88 100644 --- a/protarrow/cast_to_proto.py +++ b/protarrow/cast_to_proto.py @@ -3,6 +3,7 @@ import pyarrow as pa import pyarrow.compute as pc from google.protobuf.descriptor import Descriptor, FieldDescriptor +from google.protobuf.duration_pb2 import Duration from google.protobuf.message import Message from google.protobuf.timestamp_pb2 import Timestamp from google.type.timeofday_pb2 import TimeOfDay @@ -52,6 +53,8 @@ def _cast_flat_array( return array.cast(config.time_of_day_type) elif field_descriptor.message_type == Timestamp.DESCRIPTOR: return array.cast(config.timestamp_type) + elif field_descriptor.message_type == Duration.DESCRIPTOR: + return array.cast(config.duration_type) elif field_descriptor.message_type in _PROTO_DESCRIPTOR_TO_PYARROW: return array.cast( _PROTO_DESCRIPTOR_TO_PYARROW[field_descriptor.message_type] diff --git a/protarrow/common.py b/protarrow/common.py index cab72e5..f750ca3 100644 --- a/protarrow/common.py +++ b/protarrow/common.py @@ -20,6 +20,7 @@ class ProtarrowConfig: enum_type: pa.DataType = pa.int32() timestamp_type: pa.TimestampType = pa.timestamp("ns", "UTC") time_of_day_type: Union[pa.Time64Type, pa.Time32Type] = pa.time64("ns") + duration_type: pa.DurationType = pa.duration("ns") list_nullable: bool = False map_nullable: bool = False list_value_nullable: bool = False diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index 8c7d55c..403c0a0 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -19,6 +19,7 @@ import pyarrow.compute as pc from google.protobuf.descriptor import Descriptor, EnumDescriptor, FieldDescriptor from google.protobuf.descriptor_pb2 import FieldDescriptorProto +from google.protobuf.duration_pb2 import Duration from google.protobuf.internal.containers import MessageMap, RepeatedScalarFieldContainer from google.protobuf.message import Message from google.protobuf.timestamp_pb2 import Timestamp @@ -127,6 +128,13 @@ def _proto_date_to_py_date(proto_date: Date) -> datetime.date: "ns": _time_of_day_to_nanos, } +_DURATION_CONVERTERS = { + "s": Duration.ToSeconds, + "ms": Duration.ToMilliseconds, + "us": Duration.ToMicroseconds, + "ns": Duration.ToNanoseconds, +} + @dataclasses.dataclass(frozen=True) class FlattenedIterable(collections.abc.Iterable): @@ -286,6 +294,8 @@ def field_descriptor_to_data_type( return config.timestamp_type elif field_descriptor.message_type == TimeOfDay.DESCRIPTOR: return config.time_of_day_type + elif field_descriptor.message_type == Duration.DESCRIPTOR: + return config.duration_type elif field_descriptor.type == FieldDescriptorProto.TYPE_MESSAGE: try: return _PROTO_DESCRIPTOR_TO_PYARROW[field_descriptor.message_type] @@ -314,6 +324,8 @@ def _get_converter( ) -> Optional[Callable[[Any], Any]]: if field_descriptor.message_type == Timestamp.DESCRIPTOR: return _TIMESTAMP_CONVERTERS[config.timestamp_type.unit] + elif field_descriptor.message_type == Duration.DESCRIPTOR: + return _DURATION_CONVERTERS[config.duration_type.unit] elif field_descriptor.message_type == TimeOfDay.DESCRIPTOR: return _TIME_OF_DAY_CONVERTERS[config.time_of_day_type.unit] elif field_descriptor.type == FieldDescriptorProto.TYPE_MESSAGE: diff --git a/protos/bench.proto b/protos/bench.proto index d65c498..5c98502 100644 --- a/protos/bench.proto +++ b/protos/bench.proto @@ -3,6 +3,7 @@ syntax = "proto3"; +import "google/protobuf/duration.proto"; import "google/protobuf/empty.proto"; import "google/protobuf/timestamp.proto"; import "google/protobuf/wrappers.proto"; @@ -51,120 +52,124 @@ message ExampleMessage { google.type.Date date_value = 27; google.type.TimeOfDay time_of_day_value = 28; google.protobuf.Empty empty_value = 29; + google.protobuf.Duration duration_value = 30; // Repeated values - repeated double double_values = 30; - repeated float float_values = 31; - repeated int32 int32_values = 32; - repeated int64 int64_values = 33; - repeated uint32 uint32_values = 34; - repeated uint64 uint64_values = 35; - repeated sint32 sint32_values = 36; - repeated sint64 sint64_values = 37; - repeated fixed32 fixed32_values = 38; - repeated fixed64 fixed64_values = 39; - repeated sfixed32 sfixed32_values = 40; - repeated sfixed64 sfixed64_values = 41; - repeated bool bool_values = 42; - repeated string string_values = 43; - repeated bytes bytes_values = 44; - repeated google.protobuf.DoubleValue wrapped_double_values = 45; - repeated google.protobuf.FloatValue wrapped_float_values = 46; - repeated google.protobuf.Int32Value wrapped_int32_values = 47; - repeated google.protobuf.Int64Value wrapped_int64_values = 48; - repeated google.protobuf.UInt32Value wrapped_uint32_values = 49; - repeated google.protobuf.UInt64Value wrapped_uint64_values = 50; - repeated google.protobuf.BoolValue wrapped_bool_values = 51; - repeated google.protobuf.StringValue wrapped_string_values = 52; - repeated google.protobuf.BytesValue wrapped_bytes_values = 53; - repeated ExampleEnum example_enum_values = 54; - repeated google.protobuf.Timestamp timestamp_values = 55; - repeated google.type.Date date_values = 56; - repeated google.type.TimeOfDay time_of_day_values = 57; - repeated google.protobuf.Empty empty_values = 58; + repeated double double_values = 31; + repeated float float_values = 32; + repeated int32 int32_values = 33; + repeated int64 int64_values = 34; + repeated uint32 uint32_values = 35; + repeated uint64 uint64_values = 36; + repeated sint32 sint32_values = 37; + repeated sint64 sint64_values = 38; + repeated fixed32 fixed32_values = 39; + repeated fixed64 fixed64_values = 40; + repeated sfixed32 sfixed32_values = 41; + repeated sfixed64 sfixed64_values = 42; + repeated bool bool_values = 43; + repeated string string_values = 44; + repeated bytes bytes_values = 45; + repeated google.protobuf.DoubleValue wrapped_double_values = 46; + repeated google.protobuf.FloatValue wrapped_float_values = 47; + repeated google.protobuf.Int32Value wrapped_int32_values = 48; + repeated google.protobuf.Int64Value wrapped_int64_values = 49; + repeated google.protobuf.UInt32Value wrapped_uint32_values = 50; + repeated google.protobuf.UInt64Value wrapped_uint64_values = 51; + repeated google.protobuf.BoolValue wrapped_bool_values = 52; + repeated google.protobuf.StringValue wrapped_string_values = 53; + repeated google.protobuf.BytesValue wrapped_bytes_values = 54; + repeated ExampleEnum example_enum_values = 55; + repeated google.protobuf.Timestamp timestamp_values = 56; + repeated google.type.Date date_values = 57; + repeated google.type.TimeOfDay time_of_day_values = 58; + repeated google.protobuf.Empty empty_values = 59; + repeated google.protobuf.Duration duration_values = 60; // Map with int32 keys - map double_int32_map = 88; - map float_int32_map = 89; - map int32_int32_map = 90; - map int64_int32_map = 91; - map uint32_int32_map = 92; - map uint64_int32_map = 93; - map sint32_int32_map = 94; - map sint64_int32_map = 95; - map fixed32_int32_map = 96; - map fixed64_int32_map = 97; - map sfixed32_int32_map = 98; - map sfixed64_int32_map = 99; - map bool_int32_map = 100; - map string_int32_map = 101; - map bytes_int32_map = 102; - map wrapped_double_int32_map = 103; - map wrapped_float_int32_map = 104; - map wrapped_int32_int32_map = 105; - map wrapped_int64_int32_map = 106; - map wrapped_uint32_int32_map = 107; - map wrapped_uint64_int32_map = 108; - map wrapped_bool_int32_map = 109; - map wrapped_string_int32_map = 110; - map wrapped_bytes_int32_map = 111; - map example_enum_int32_map = 112; - map timestamp_int32_map = 113; - map date_int32_map = 114; - map time_of_day_int32_map = 115; - map empty_int32_map = 116; + map double_int32_map = 91; + map float_int32_map = 92; + map int32_int32_map = 93; + map int64_int32_map = 94; + map uint32_int32_map = 95; + map uint64_int32_map = 96; + map sint32_int32_map = 97; + map sint64_int32_map = 98; + map fixed32_int32_map = 99; + map fixed64_int32_map = 100; + map sfixed32_int32_map = 101; + map sfixed64_int32_map = 102; + map bool_int32_map = 103; + map string_int32_map = 104; + map bytes_int32_map = 105; + map wrapped_double_int32_map = 106; + map wrapped_float_int32_map = 107; + map wrapped_int32_int32_map = 108; + map wrapped_int64_int32_map = 109; + map wrapped_uint32_int32_map = 110; + map wrapped_uint64_int32_map = 111; + map wrapped_bool_int32_map = 112; + map wrapped_string_int32_map = 113; + map wrapped_bytes_int32_map = 114; + map example_enum_int32_map = 115; + map timestamp_int32_map = 116; + map date_int32_map = 117; + map time_of_day_int32_map = 118; + map empty_int32_map = 119; + map duration_int32_map = 120; // Map with string keys - map double_string_map = 117; - map float_string_map = 118; - map int32_string_map = 119; - map int64_string_map = 120; - map uint32_string_map = 121; - map uint64_string_map = 122; - map sint32_string_map = 123; - map sint64_string_map = 124; - map fixed32_string_map = 125; - map fixed64_string_map = 126; - map sfixed32_string_map = 127; - map sfixed64_string_map = 128; - map bool_string_map = 129; - map string_string_map = 130; - map bytes_string_map = 131; - map wrapped_double_string_map = 132; - map wrapped_float_string_map = 133; - map wrapped_int32_string_map = 134; - map wrapped_int64_string_map = 135; - map wrapped_uint32_string_map = 136; - map wrapped_uint64_string_map = 137; - map wrapped_bool_string_map = 138; - map wrapped_string_string_map = 139; - map wrapped_bytes_string_map = 140; - map example_enum_string_map = 141; - map timestamp_string_map = 142; - map date_string_map = 143; - map time_of_day_string_map = 144; - map empty_string_map = 145; + map double_string_map = 121; + map float_string_map = 122; + map int32_string_map = 123; + map int64_string_map = 124; + map uint32_string_map = 125; + map uint64_string_map = 126; + map sint32_string_map = 127; + map sint64_string_map = 128; + map fixed32_string_map = 129; + map fixed64_string_map = 130; + map sfixed32_string_map = 131; + map sfixed64_string_map = 132; + map bool_string_map = 133; + map string_string_map = 134; + map bytes_string_map = 135; + map wrapped_double_string_map = 136; + map wrapped_float_string_map = 137; + map wrapped_int32_string_map = 138; + map wrapped_int64_string_map = 139; + map wrapped_uint32_string_map = 140; + map wrapped_uint64_string_map = 141; + map wrapped_bool_string_map = 142; + map wrapped_string_string_map = 143; + map wrapped_bytes_string_map = 144; + map example_enum_string_map = 145; + map timestamp_string_map = 146; + map date_string_map = 147; + map time_of_day_string_map = 148; + map empty_string_map = 149; + map duration_string_map = 150; // Optional - optional double optional_double_value = 175; - optional float optional_float_value = 176; - optional int32 optional_int32_value = 177; - optional int64 optional_int64_value = 178; - optional uint32 optional_uint32_value = 179; - optional uint64 optional_uint64_value = 180; - optional sint32 optional_sint32_value = 181; - optional sint64 optional_sint64_value = 182; - optional fixed32 optional_fixed32_value = 183; - optional fixed64 optional_fixed64_value = 184; - optional sfixed32 optional_sfixed32_value = 185; - optional sfixed64 optional_sfixed64_value = 186; - optional bool optional_bool_value = 187; - optional string optional_string_value = 188; - optional bytes optional_bytes_value = 189; - optional ExampleEnum optional_example_enum_value = 199; + optional double optional_double_value = 181; + optional float optional_float_value = 182; + optional int32 optional_int32_value = 183; + optional int64 optional_int64_value = 184; + optional uint32 optional_uint32_value = 185; + optional uint64 optional_uint64_value = 186; + optional sint32 optional_sint32_value = 187; + optional sint64 optional_sint64_value = 188; + optional fixed32 optional_fixed32_value = 189; + optional fixed64 optional_fixed64_value = 190; + optional sfixed32 optional_sfixed32_value = 191; + optional sfixed64 optional_sfixed64_value = 192; + optional bool optional_bool_value = 193; + optional string optional_string_value = 194; + optional bytes optional_bytes_value = 195; + optional ExampleEnum optional_example_enum_value = 205; } message NestedExampleMessage { diff --git a/scripts/generate_proto.py b/scripts/generate_proto.py index a5c2b17..a8ea8d6 100644 --- a/scripts/generate_proto.py +++ b/scripts/generate_proto.py @@ -66,6 +66,7 @@ def can_be_optional(self) -> bool: TypeTemplate.logical("google.type.Date"), TypeTemplate.logical("google.type.TimeOfDay"), TypeTemplate.logical("google.protobuf.Empty"), + TypeTemplate.logical("google.protobuf.Duration"), ] MAP_KEYS = ["int32", "string"] diff --git a/scripts/template.proto.in b/scripts/template.proto.in index e196c67..abefd88 100644 --- a/scripts/template.proto.in +++ b/scripts/template.proto.in @@ -3,6 +3,7 @@ syntax = "proto3"; +import "google/protobuf/duration.proto"; import "google/protobuf/empty.proto"; import "google/protobuf/timestamp.proto"; import "google/protobuf/wrappers.proto"; diff --git a/tests/random_generator.py b/tests/random_generator.py index 4eaca99..7d66471 100644 --- a/tests/random_generator.py +++ b/tests/random_generator.py @@ -4,11 +4,13 @@ import typing from google.protobuf.descriptor import EnumDescriptor, FieldDescriptor +from google.protobuf.duration_pb2 import Duration from google.protobuf.message import Message from google.protobuf.timestamp_pb2 import Timestamp from google.type.date_pb2 import Date from google.type.timeofday_pb2 import TimeOfDay +import protarrow from protarrow.common import M from protarrow.proto_to_arrow import is_map @@ -32,6 +34,13 @@ def random_timestamp() -> Timestamp: ) +def random_duration() -> Duration: + return Duration( + seconds=random.randint(-9223372036, 9223372035), + nanos=random.randint(0, 999_999_999), + ) + + def random_date() -> Date: date = datetime.date.min + datetime.timedelta(days=random.randint(0, 3652058)) return Date(year=date.year, month=date.month, day=date.day) @@ -65,6 +74,7 @@ def random_time_of_day() -> TimeOfDay: Date.DESCRIPTOR: random_date, Timestamp.DESCRIPTOR: random_timestamp, TimeOfDay.DESCRIPTOR: random_time_of_day, + Duration.DESCRIPTOR: random_duration, } @@ -145,11 +155,35 @@ def _generate_enum(enum: EnumDescriptor) -> int: return random.choice(enum.values).index -def truncate_nanos(message: Message, timestamp_unit: str, time_unit: str) -> Message: +def truncate_messages( + messages: list[Message], config: protarrow.ProtarrowConfig +) -> list[Message]: + return [ + truncate_nanos( + m, + config.timestamp_type.unit, + config.time_of_day_type.unit, + config.duration_type.unit, + ) + for m in messages + ] + + +def truncate_nanos( + message: Message, + timestamp_unit: str, + time_unit: str, + duration_unit: str, +) -> Message: if message.DESCRIPTOR == Timestamp.DESCRIPTOR: message.nanos = ( message.nanos // UNIT_IN_NANOS[timestamp_unit] ) * UNIT_IN_NANOS[timestamp_unit] + elif message.DESCRIPTOR == Duration.DESCRIPTOR: + message.nanos = (message.nanos // UNIT_IN_NANOS[duration_unit]) * UNIT_IN_NANOS[ + duration_unit + ] + elif message.DESCRIPTOR == TimeOfDay.DESCRIPTOR: message.nanos = (message.nanos // UNIT_IN_NANOS[time_unit]) * UNIT_IN_NANOS[ time_unit @@ -168,13 +202,20 @@ def truncate_nanos(message: Message, timestamp_unit: str, time_unit: str) -> Mes == FieldDescriptor.TYPE_MESSAGE ): for key, value in field_value.items(): - truncate_nanos(value, timestamp_unit, time_unit) + truncate_nanos( + value, timestamp_unit, time_unit, duration_unit + ) else: for item in field_value: - truncate_nanos(item, timestamp_unit, time_unit) + truncate_nanos( + item, timestamp_unit, time_unit, duration_unit + ) elif message.HasField(field.name): truncate_nanos( - getattr(message, field.name), timestamp_unit, time_unit + getattr(message, field.name), + timestamp_unit, + time_unit, + duration_unit, ) return message diff --git a/tests/test_conversion.py b/tests/test_conversion.py index e2fd660..92e1cb2 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -4,6 +4,7 @@ import pyarrow as pa import pytest from google.protobuf.descriptor import FieldDescriptor +from google.protobuf.duration_pb2 import Duration from google.protobuf.empty_pb2 import Empty from google.protobuf.json_format import Parse from google.protobuf.message import Message @@ -40,8 +41,9 @@ SuperNestedExampleMessage, ) from protarrow_protos.example_pb2 import EmptyMessage, NestedEmptyMessage -from tests.random_generator import generate_messages, truncate_nanos +from tests.random_generator import generate_messages, truncate_messages, truncate_nanos +TEST_MESSAGE_COUNT = 5 MESSAGES = [ExampleMessage, NestedExampleMessage, SuperNestedExampleMessage] CONFIGS = [ ProtarrowConfig(), @@ -62,6 +64,10 @@ ProtarrowConfig(time_of_day_type=pa.time64("us")), ProtarrowConfig(time_of_day_type=pa.time32("ms")), ProtarrowConfig(time_of_day_type=pa.time32("s")), + ProtarrowConfig(duration_type=pa.duration("s")), + ProtarrowConfig(duration_type=pa.duration("ms")), + ProtarrowConfig(duration_type=pa.duration("us")), + ProtarrowConfig(duration_type=pa.duration("ns")), ProtarrowConfig(list_nullable=True), ProtarrowConfig(map_nullable=True), ProtarrowConfig(map_value_nullable=True), @@ -105,10 +111,7 @@ def test_with_random(message_type: Type[Message], config: ProtarrowConfig): source_messages = generate_messages(message_type, 10) table = messages_to_table(source_messages, message_type, config) messages_back = table_to_messages(table, message_type) - truncated_messages = [ - truncate_nanos(m, config.timestamp_type.unit, config.time_of_day_type.unit) - for m in source_messages - ] + truncated_messages = truncate_messages(source_messages, config) _check_messages_same(truncated_messages, messages_back) @@ -121,10 +124,7 @@ def test_with_random_not_aligned( source_messages = generate_messages(message_type, 3) table = messages_to_table(source_messages, message_type, config) messages_back = table_to_messages(table[index:], message_type) - truncated_messages = [ - truncate_nanos(m, config.timestamp_type.unit, config.time_of_day_type.unit) - for m in source_messages[index:] - ] + truncated_messages = truncate_messages(source_messages[index:], config) _check_messages_same(truncated_messages, messages_back) @@ -138,10 +138,7 @@ def test_with_sample_data(message_type: Type[Message], config: ProtarrowConfig): table = messages_to_table(source_messages, message_type, config) messages_back = table_to_messages(table, message_type) - truncated_messages = [ - truncate_nanos(m, config.timestamp_type.unit, config.time_of_day_type.unit) - for m in source_messages - ] + truncated_messages = truncate_messages(source_messages, config) _check_messages_same(truncated_messages, messages_back) @@ -281,27 +278,29 @@ def test_check_init_sorted(): def test_truncate_nanos(): assert truncate_nanos( - Timestamp(seconds=10, nanos=123456789), - "s", - "us", + Timestamp(seconds=10, nanos=123456789), "s", "us", "s" ) == Timestamp(seconds=10) assert truncate_nanos( - Timestamp(seconds=10, nanos=123456789), "ms", "us" + Timestamp(seconds=10, nanos=123456789), "ms", "us", "s" ) == Timestamp(seconds=10, nanos=123000000) assert truncate_nanos( - Timestamp(seconds=10, nanos=123456789), "us", "us" + Timestamp(seconds=10, nanos=123456789), "us", "us", "s" ) == Timestamp(seconds=10, nanos=123456000) assert truncate_nanos( - Timestamp(seconds=10, nanos=123456789), "ns", "us" + Timestamp(seconds=10, nanos=123456789), "ns", "us", "s" ) == Timestamp(seconds=10, nanos=123456789) assert truncate_nanos( - TimeOfDay(seconds=10, nanos=123456789), "ns", "us" + TimeOfDay(seconds=10, nanos=123456789), "ns", "us", "s" ) == TimeOfDay(seconds=10, nanos=123456000) + assert truncate_nanos( + Duration(seconds=10, nanos=123456789), "ns", "us", "us" + ) == Duration(seconds=10, nanos=123456000) + def test_truncate_nested(): assert truncate_nanos( @@ -312,6 +311,7 @@ def test_truncate_nested(): ), "us", "ms", + "ns", ) == ExampleMessage( timestamp_value=Timestamp(seconds=10, nanos=123_456_000), timestamp_string_map={"foo": Timestamp(seconds=10, nanos=123_456_000)}, @@ -330,6 +330,7 @@ def test_truncate_nested_nested(): ), "us", "ms", + "ns", ) == NestedExampleMessage( example_message=ExampleMessage( timestamp_value=Timestamp(seconds=10, nanos=123_456_000), @@ -433,7 +434,7 @@ def test_create_enum_converter_wrong_type(): @pytest.mark.parametrize("message_type", MESSAGES) @pytest.mark.parametrize("config", CONFIGS) def test_cast_empty(message_type: Type[Message], config: ProtarrowConfig): - table = pa.table({"nulls": pa.nulls(10, pa.null())}) + table = pa.table({"nulls": pa.nulls(TEST_MESSAGE_COUNT, pa.null())}) casted_table = cast_table(table, message_type, config) assert len(table) == len(casted_table) assert casted_table.schema == message_type_to_schema(message_type, config) @@ -442,7 +443,7 @@ def test_cast_empty(message_type: Type[Message], config: ProtarrowConfig): @pytest.mark.parametrize("message_type", MESSAGES) @pytest.mark.parametrize("config", CONFIGS) def test_cast_same(message_type: Type[Message], config: ProtarrowConfig): - source_messages = generate_messages(message_type, 10) + source_messages = generate_messages(message_type, TEST_MESSAGE_COUNT) table = messages_to_table(source_messages, message_type, config) casted_table = cast_table(table, message_type, config) assert table == casted_table @@ -451,10 +452,11 @@ def test_cast_same(message_type: Type[Message], config: ProtarrowConfig): @pytest.mark.parametrize("message_type", MESSAGES) @pytest.mark.parametrize("config", CONFIGS) def test_cast_same_view(message_type: Type[Message], config: ProtarrowConfig): - source_messages = generate_messages(message_type, 10) + source_messages = generate_messages(message_type, TEST_MESSAGE_COUNT) table = messages_to_table(source_messages, message_type, config) - view = table[5:] + assert 2 + 1 < TEST_MESSAGE_COUNT, "View too small" + view = table[2:] casted_view = cast_table(view, message_type, config) assert casted_view == view @@ -476,7 +478,7 @@ def test_can_cast_enum_to_dictionary_and_back( plain_config = ProtarrowConfig(enum_type=pa.string()) dict_config = ProtarrowConfig(enum_type=pa.dictionary(pa.int32(), pa.string())) - source_messages = generate_messages(message_type, 10) + source_messages = generate_messages(message_type, TEST_MESSAGE_COUNT) plain_table = messages_to_table(source_messages, message_type, plain_config) dict_table = messages_to_table(source_messages, message_type, dict_config) @@ -490,17 +492,14 @@ def test_can_cast_enum_to_dictionary_and_back( @pytest.mark.parametrize("message_type", MESSAGES) @pytest.mark.parametrize("config", CONFIGS) def test_extractor(message_type: Type[Message], config: ProtarrowConfig): - source_messages = [m for m in generate_messages(message_type, 10)] + source_messages = generate_messages(message_type, TEST_MESSAGE_COUNT) table = messages_to_table(source_messages, message_type, config) message_extractor = MessageExtractor(table.schema, message_type) messages_back = [ message_extractor.read_table_row(table, row) for row in range(len(table)) ] - truncated_messages = [ - truncate_nanos(m, config.timestamp_type.unit, config.time_of_day_type.unit) - for m in source_messages - ] + truncated_messages = truncate_messages(source_messages, config) _check_messages_same(truncated_messages, messages_back) @@ -802,3 +801,24 @@ def test_map_cast(): ].to_pylist() == array[1:].to_pylist() ) + + +@pytest.mark.parametrize("config", CONFIGS) +def test_duration(config): + messages = [ExampleMessage(duration_value=Duration(seconds=10, nanos=123456789))] + table = protarrow.messages_to_record_batch(messages, ExampleMessage, config) + assert isinstance(table, pa.RecordBatch) + messages_back = protarrow.record_batch_to_messages(table, ExampleMessage) + expected = truncate_messages(messages, config) + assert messages_back == expected + + +def test_duration_specific(): + messages = [ExampleMessage(duration_value=Duration(seconds=10, nanos=123456789))] + config = protarrow.ProtarrowConfig(duration_type=pa.duration("us")) + expected = truncate_messages(messages, config) + table = protarrow.messages_to_record_batch(messages, ExampleMessage, config) + assert table.schema.field("duration_value").type == pa.duration("us") + assert isinstance(table, pa.RecordBatch) + messages_back = protarrow.record_batch_to_messages(table, ExampleMessage) + assert messages_back == expected