From 3f06cb8f421d3aed27a17e6ef3ce5a1996296841 Mon Sep 17 00:00:00 2001 From: Kriskras99 Date: Mon, 21 Jul 2025 20:29:00 +0200 Subject: [PATCH 1/7] rfc --- rfc.md | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 rfc.md diff --git a/rfc.md b/rfc.md new file mode 100644 index 00000000..2c5fff1b --- /dev/null +++ b/rfc.md @@ -0,0 +1,88 @@ +# Possible implementations + +## Maintaining two separate implementations + +Pros: + - Easy to implement (just copy-paste the blocking implementation and start inserting `async`/`await`) + - Allows for optimal performance in both situations + - *Should* be able to share at least a part of the implementation + +Cons: + - Maintenance, any bug needs to be fixed in both implementations. Same goes for testing. + - Hard to onboard, new contributors will be confronted with a very large codebase (see [Good ol' copy-pasting](https://nullderef.com/blog/rust-async-sync/#good-ol-copy-pasting)) + - Adding new functionality means implementing it twice. + +## Implement in async, use `block_on` for sync implementation + +In this implementation, the core codebase is implemented asynchronously. A `blocking` module is provided which wraps +the async functions/types in `block_on` calls. Recreating the runtime on every call is very slow, so to make this work +it would involve spawning a thread for the runtime and using that to spawn the async functions. This is how `reqwest` +implements their async/sync code. 
+ +Pros: + - Only need to maintain/test/upgrade one implementation + - Optimal performance for async code + +Cons: + - Degrades sync performance + - Need to pull in a runtime when the `blocking` feature is enabled (`reqwest` use `tokio` but something like `smoll` might make more sense) + +## Implement in async, use `maybe_async` to generate sync implementation + +[`maybe_async`](https://crates.io/crates/maybe-async) is a proc macro that removes the `.await` from the async code and uses it to generate sync code. + +Pros: + - Only need to maintain/test/upgrade one implementation + - Optimal performance for both async and sync code + +Cons: + - Crate breaks if both the `sync` and `async` features are enabled + +## Sans I/O + +Implement the parser as a state machine that can be driven by both async and sync code. This is how [`rc-zip`](https://lib.rs/crates/rc-zip) +is implemented. + +Pros: + - Only need to maintain/test/upgrade one implementation + - Optimal performance for both async and sync code + +Cons: + - Have to manually implement the state machines + - In the distant future [it's possible to use coroutines/generators](https://internals.rust-lang.org/t/using-coroutines-for-a-sans-io-parser/22968), but they're currently *very* unstable. + +## Do not provide an async implementation + +Pros: + - Easiest option, nothing has to change + +Cons: + - An async implementation is really nice for using Avro over the network + +# Serde + +One problem not mentioned yet, is that Serde does not have an async interface. This doesn't necessarily have to be a problem. +The current deserialize implementation also first decodes a `avro::Value` and then uses that to deserialize the Serde type (reverse for serialize). +The decoding to `avro::Value` can be made async, and then the serde part can be done in a sync way as it does not use any I/O. 
+ +Some alternative options: +- [tokio-serde](https://docs.rs/tokio-serde/latest/tokio_serde/index.html) + - A wrapper around Serde that requires the user to split the input into frames containing one object. +- [destream](https://docs.rs/destream/0.9.0/destream/index.html) + - Async versions of the Serde traits, but not compatible with serde so lacks ecosystem support. + +# Best option? + +I'm currently leaning towards implementing Sans I/O. It provides an (almost) optimal implementation for both async and sync code. +It doesn't duplicate code (except the interfaces) and doesn't require pulling in any runtime (only parts of `futures`). + +Care needs to be taken that the state machines are kept small and understandable. + +The second-best option is probably using `block_on` in a separate thread. But that seems unnecessarily heavy. + +# References + +- [Blog post by the maintainer of `RSpotify` who tried multiple of the above options](https://nullderef.com/blog/rust-async-sync/) +- [A discussion about Sans I/O](https://sdr-podcast.com/episodes/sans-io/) +- [A explanation of Sans I/O by the author of `rc-zip`](https://fasterthanli.me/articles/the-case-for-sans-io) + - The blog post is currently not freely available, but the video (which has the exact same content) is freely available From fb9f14735c8e186ce3e6ac70e8ce715605ea289b Mon Sep 17 00:00:00 2001 From: Kriskras99 Date: Tue, 22 Jul 2025 14:37:09 +0200 Subject: [PATCH 2/7] Update rfc.md Co-authored-by: Martin Grigorov --- rfc.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rfc.md b/rfc.md index 2c5fff1b..a0a2aede 100644 --- a/rfc.md +++ b/rfc.md @@ -85,4 +85,4 @@ The second-best option is probably using `block_on` in a separate thread. 
But th - [Blog post by the maintainer of `RSpotify` who tried multiple of the above options](https://nullderef.com/blog/rust-async-sync/) - [A discussion about Sans I/O](https://sdr-podcast.com/episodes/sans-io/) - [A explanation of Sans I/O by the author of `rc-zip`](https://fasterthanli.me/articles/the-case-for-sans-io) - - The blog post is currently not freely available, but the video (which has the exact same content) is freely available + - The blog post is currently not freely available, but the [video](https://www.youtube.com/watch?v=RYHYiXMJdZI) (which has the exact same content) is freely available From c1d800911cbf2c7b0874bb1720003e04ff37266d Mon Sep 17 00:00:00 2001 From: Kriskras99 Date: Wed, 23 Jul 2025 16:29:25 +0200 Subject: [PATCH 3/7] Update rfc.md --- rfc.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rfc.md b/rfc.md index a0a2aede..38741cc8 100644 --- a/rfc.md +++ b/rfc.md @@ -29,14 +29,15 @@ Cons: ## Implement in async, use `maybe_async` to generate sync implementation -[`maybe_async`](https://crates.io/crates/maybe-async) is a proc macro that removes the `.await` from the async code and uses it to generate sync code. +[`maybe_async`](https://crates.io/crates/maybe-async) is a proc macro that removes the `.await` from the async code and uses it to generate sync code. [`synca`](https://docs.rs/synca/latest/synca/) is another option where both sync and async code can coexist. 
Pros: - Only need to maintain/test/upgrade one implementation - Optimal performance for both async and sync code Cons: - - Crate breaks if both the `sync` and `async` features are enabled + - Crate breaks if both the `sync` and `async` features are enabled (only for `maybe_async`) + - `synca` hasn't seen an update in more than a year, but seems to be feature complete ## Sans I/O From 63a648e29a5fb49dca3647d0b526b68d5aeccceb Mon Sep 17 00:00:00 2001 From: Kriskras99 Date: Wed, 23 Jul 2025 16:31:44 +0200 Subject: [PATCH 4/7] Update rfc.md --- rfc.md | 1 + 1 file changed, 1 insertion(+) diff --git a/rfc.md b/rfc.md index 38741cc8..1c2c88dc 100644 --- a/rfc.md +++ b/rfc.md @@ -51,6 +51,7 @@ Pros: Cons: - Have to manually implement the state machines - In the distant future [it's possible to use coroutines/generators](https://internals.rust-lang.org/t/using-coroutines-for-a-sans-io-parser/22968), but they're currently *very* unstable. + - You can use async functions to generate the state machines for you, [according to this blogpost](https://jeffmcbride.net/blog/2025/05/16/rust-async-functions-as-state-machines/) ## Do not provide an async implementation From 6c2ddf1bf36bd2602d57e87e8c115b7616d2f8b1 Mon Sep 17 00:00:00 2001 From: kriskras99 Date: Fri, 25 Jul 2025 17:45:03 +0000 Subject: [PATCH 5/7] State machine example --- Cargo.lock | 171 ++- avro/Cargo.toml | 6 + avro/src/bigdecimal.rs | 29 +- avro/src/decode.rs | 875 ------------- avro/src/error.rs | 10 + avro/src/lib.rs | 8 +- avro/src/reader.rs | 602 ++------- avro/src/state_machines/mod.rs | 1 + avro/src/state_machines/reading/async_impl.rs | 302 +++++ avro/src/state_machines/reading/block.rs | 110 ++ avro/src/state_machines/reading/bytes.rs | 71 + avro/src/state_machines/reading/codec.rs | 180 +++ avro/src/state_machines/reading/commands.rs | 646 +++++++++ avro/src/state_machines/reading/datum.rs | 80 ++ avro/src/state_machines/reading/error.rs | 16 + avro/src/state_machines/reading/mod.rs | 1151 
+++++++++++++++++ .../reading/object_container_file.rs | 318 +++++ avro/src/state_machines/reading/sync.rs | 351 +++++ avro/src/state_machines/reading/union.rs | 78 ++ avro/src/types.rs | 1 + avro/src/util.rs | 101 +- avro_derive/tests/derive.proptest-regressions | 12 + rustfmt.toml | 2 + 23 files changed, 3674 insertions(+), 1447 deletions(-) delete mode 100644 avro/src/decode.rs create mode 100644 avro/src/state_machines/mod.rs create mode 100644 avro/src/state_machines/reading/async_impl.rs create mode 100644 avro/src/state_machines/reading/block.rs create mode 100644 avro/src/state_machines/reading/bytes.rs create mode 100644 avro/src/state_machines/reading/codec.rs create mode 100644 avro/src/state_machines/reading/commands.rs create mode 100644 avro/src/state_machines/reading/datum.rs create mode 100644 avro/src/state_machines/reading/error.rs create mode 100644 avro/src/state_machines/reading/mod.rs create mode 100644 avro/src/state_machines/reading/object_container_file.rs create mode 100644 avro/src/state_machines/reading/sync.rs create mode 100644 avro/src/state_machines/reading/union.rs create mode 100644 avro_derive/tests/derive.proptest-regressions create mode 100644 rustfmt.toml diff --git a/Cargo.lock b/Cargo.lock index b97fd78d..8da21018 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -19,9 +19,9 @@ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] @@ -51,18 +51,21 @@ dependencies = [ "anyhow", "apache-avro-derive", "apache-avro-test-helper", + "async-stream", "bigdecimal", "bon", "bzip2", "crc32fast", "criterion", "digest", + "futures", "hex-literal", "liblzma", "log", "md-5", "miniz_oxide", "num-bigint", + 
"oval", "paste", "pretty_assertions", "quad-rand", @@ -107,6 +110,28 @@ dependencies = [ "log", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -198,6 +223,12 @@ version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + [[package]] name = "bzip2" version = "0.6.1" @@ -215,9 +246,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.43" +version = "1.2.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "739eb0f94557554b3ca9a86d2d37bebd49c5e6d0c1d2bda35ba5bdac830befc2" +checksum = "37521ac7aabe3d13122dc382493e20c9416f299d2ccd5b3a5340a2570cdeb0f3" dependencies = [ "find-msvc-tools", "jobserver", @@ -260,18 +291,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.50" +version = "4.5.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2cfd7bf8a6017ddaa4e32ffe7403d547790db06bd171c1c53926faab501623" +checksum = "4c26d721170e0295f191a69bd9a1f93efcdb0aff38684b61ab5750468972e5f5" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.50" +version = "4.5.51" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4c05b9e80c5ccd3a7ef080ad7b6ba7d6fc00a985b8b157197075677c82c7a0" +checksum = "75835f0c7bf681bfd05abe44e965760fea999a5286c6eb2d59883634fd02011a" dependencies = [ "anstyle", "clap_lex", @@ -439,9 +470,9 @@ dependencies = [ [[package]] name = "dtor" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e58a0764cddb55ab28955347b45be00ade43d4d6f3ba4bf3dc354e4ec9432934" +checksum = "404d02eeb088a82cfd873006cb713fe411306c7d182c344905e101fb1167d301" dependencies = [ "dtor-proc-macro", ] @@ -495,6 +526,95 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.9" @@ -733,12 +853,33 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "oval" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135cef32720c6746450d910890b0b69bcba2bbf6f85c9f4583df13fe415de828" +dependencies = [ + "bytes", +] + [[package]] name = "paste" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.32" @@ -1036,6 +1177,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "snap" version = "1.1.1" @@ -1121,9 +1268,9 @@ checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" [[package]] name = "unicode-ident" -version = "1.0.20" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" [[package]] name = "uuid" diff --git a/avro/Cargo.toml b/avro/Cargo.toml index 975d2f26..e22a85ad 100644 --- a/avro/Cargo.toml +++ b/avro/Cargo.toml @@ -29,11 +29,14 @@ categories.workspace = true documentation.workspace = true [features] +default = ["futures", "sync"] bzip = ["dep:bzip2"] derive = ["dep:apache-avro-derive"] snappy = ["dep:crc32fast", "dep:snap"] xz = ["dep:liblzma"] zstandard = ["dep:zstd"] +futures = [] +sync = [] [lib] # disable benchmarks to allow passing criterion arguments to `cargo bench` @@ -73,6 +76,9 @@ thiserror = { default-features = false, version = "2.0.17" } uuid = { default-features = false, version = "1.18.1", features = ["serde", "std"] } liblzma = { default-features = false, version = "0.4.5", optional = true } zstd = { default-features = false, version = "0.13.3", optional = true } +oval = { version = "2.0.0", 
features = ["bytes"] } +futures = "0.3.31" +async-stream = "0.3.6" [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/avro/src/bigdecimal.rs b/avro/src/bigdecimal.rs index 6022f0dc..ee337356 100644 --- a/avro/src/bigdecimal.rs +++ b/avro/src/bigdecimal.rs @@ -17,10 +17,9 @@ use crate::{ AvroResult, - decode::{decode_len, decode_long}, encode::{encode_bytes, encode_long}, error::Details, - types::Value, + util::{decode_len_simple, decode_variable}, }; pub use bigdecimal::BigDecimal; use num_bigint::BigInt; @@ -47,10 +46,12 @@ pub(crate) fn serialize_big_decimal(decimal: &BigDecimal) -> AvroResult> Ok(final_buffer) } -pub(crate) fn deserialize_big_decimal(bytes: &Vec) -> AvroResult { - let mut bytes: &[u8] = bytes.as_slice(); - let mut big_decimal_buffer = match decode_len(&mut bytes) { - Ok(size) => vec![0u8; size], +pub(crate) fn deserialize_big_decimal(mut bytes: &[u8]) -> AvroResult { + let mut big_decimal_buffer = match decode_len_simple(bytes) { + Ok((size, bytes_read)) => { + bytes = &bytes[bytes_read..]; + vec![0u8; size] + } Err(err) => return Err(Details::BigDecimalLen(Box::new(err)).into()), }; @@ -58,8 +59,8 @@ pub(crate) fn deserialize_big_decimal(bytes: &Vec) -> AvroResult .read_exact(&mut big_decimal_buffer[..]) .map_err(Details::ReadDouble)?; - match decode_long(&mut bytes) { - Ok(Value::Long(scale_value)) => { + match decode_variable(bytes) { + Ok(Some((scale_value, _))) => { let big_int: BigInt = BigInt::from_signed_bytes_be(&big_decimal_buffer); let decimal = BigDecimal::new(big_int, scale_value); Ok(decimal) @@ -71,7 +72,11 @@ pub(crate) fn deserialize_big_decimal(bytes: &Vec) -> AvroResult #[cfg(test)] mod tests { use super::*; - use crate::{Codec, Reader, Schema, Writer, error::Error, types::Record}; + use crate::{ + Codec, Reader, Schema, Writer, + error::Error, + types::{Record, Value}, + }; use apache_avro_test_helper::TestResult; use bigdecimal::{One, Zero}; use pretty_assertions::assert_eq; @@ -92,7 +97,8 @@ mod tests { let 
buffer: Vec = serialize_big_decimal(¤t)?; let mut as_slice = buffer.as_slice(); - decode_long(&mut as_slice)?; + let (_, bytes_read) = decode_variable(as_slice)?.unwrap(); + as_slice = &as_slice[bytes_read..]; let mut result: Vec = Vec::new(); result.extend_from_slice(as_slice); @@ -109,7 +115,8 @@ mod tests { let buffer: Vec = serialize_big_decimal(&BigDecimal::zero())?; let mut as_slice = buffer.as_slice(); - decode_long(&mut as_slice)?; + let (_, bytes_read) = decode_variable(as_slice)?.unwrap(); + as_slice = &as_slice[bytes_read..]; let mut result: Vec = Vec::new(); result.extend_from_slice(as_slice); diff --git a/avro/src/decode.rs b/avro/src/decode.rs deleted file mode 100644 index 78fefbd9..00000000 --- a/avro/src/decode.rs +++ /dev/null @@ -1,875 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::{ - AvroResult, Error, - bigdecimal::deserialize_big_decimal, - decimal::Decimal, - duration::Duration, - error::Details, - schema::{ - DecimalSchema, EnumSchema, FixedSchema, Name, Namespace, RecordSchema, ResolvedSchema, - Schema, - }, - types::Value, - util::{safe_len, zag_i32, zag_i64}, -}; -use std::{ - borrow::Borrow, - collections::HashMap, - io::{ErrorKind, Read}, -}; -use uuid::Uuid; - -#[inline] -pub(crate) fn decode_long(reader: &mut R) -> AvroResult { - zag_i64(reader).map(Value::Long) -} - -#[inline] -fn decode_int(reader: &mut R) -> AvroResult { - zag_i32(reader).map(Value::Int) -} - -#[inline] -pub(crate) fn decode_len(reader: &mut R) -> AvroResult { - let len = zag_i64(reader)?; - safe_len(usize::try_from(len).map_err(|e| Details::ConvertI64ToUsize(e, len))?) -} - -/// Decode the length of a sequence. -/// -/// Maps and arrays are 0-terminated, 0i64 is also encoded as 0 in Avro reading a length of 0 means -/// the end of the map or array. -fn decode_seq_len(reader: &mut R) -> AvroResult { - let raw_len = zag_i64(reader)?; - safe_len( - usize::try_from(match raw_len.cmp(&0) { - std::cmp::Ordering::Equal => return Ok(0), - std::cmp::Ordering::Less => { - let _size = zag_i64(reader)?; - raw_len.checked_neg().ok_or(Details::IntegerOverflow)? - } - std::cmp::Ordering::Greater => raw_len, - }) - .map_err(|e| Details::ConvertI64ToUsize(e, raw_len))?, - ) -} - -/// Decode a `Value` from avro format given its `Schema`. 
-pub fn decode(schema: &Schema, reader: &mut R) -> AvroResult { - let rs = ResolvedSchema::try_from(schema)?; - decode_internal(schema, rs.get_names(), &None, reader) -} - -pub(crate) fn decode_internal>( - schema: &Schema, - names: &HashMap, - enclosing_namespace: &Namespace, - reader: &mut R, -) -> AvroResult { - match *schema { - Schema::Null => Ok(Value::Null), - Schema::Boolean => { - let mut buf = [0u8; 1]; - match reader.read_exact(&mut buf[..]) { - Ok(_) => match buf[0] { - 0u8 => Ok(Value::Boolean(false)), - 1u8 => Ok(Value::Boolean(true)), - _ => Err(Details::BoolValue(buf[0]).into()), - }, - Err(io_err) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - Ok(Value::Null) - } else { - Err(Details::ReadBoolean(io_err).into()) - } - } - } - } - Schema::Decimal(DecimalSchema { ref inner, .. }) => match &**inner { - Schema::Fixed { .. } => { - match decode_internal(inner, names, enclosing_namespace, reader)? { - Value::Fixed(_, bytes) => Ok(Value::Decimal(Decimal::from(bytes))), - value => Err(Details::FixedValue(value).into()), - } - } - Schema::Bytes => match decode_internal(inner, names, enclosing_namespace, reader)? { - Value::Bytes(bytes) => Ok(Value::Decimal(Decimal::from(bytes))), - value => Err(Details::BytesValue(value).into()), - }, - schema => Err(Details::ResolveDecimalSchema(schema.into()).into()), - }, - Schema::BigDecimal => { - match decode_internal(&Schema::Bytes, names, enclosing_namespace, reader)? { - Value::Bytes(bytes) => deserialize_big_decimal(&bytes).map(Value::BigDecimal), - value => Err(Details::BytesValue(value).into()), - } - } - Schema::Uuid => { - let Value::Bytes(bytes) = - decode_internal(&Schema::Bytes, names, enclosing_namespace, reader)? - else { - // Calling decode_internal with Schema::Bytes can only return a Value::Bytes or an error - unreachable!(); - }; - - let uuid = if bytes.len() == 16 { - Uuid::from_slice(&bytes).map_err(Details::ConvertSliceToUuid)? 
- } else { - let string = std::str::from_utf8(&bytes).map_err(Details::ConvertToUtf8Error)?; - Uuid::parse_str(string).map_err(Details::ConvertStrToUuid)? - }; - Ok(Value::Uuid(uuid)) - } - Schema::Int => decode_int(reader), - Schema::Date => zag_i32(reader).map(Value::Date), - Schema::TimeMillis => zag_i32(reader).map(Value::TimeMillis), - Schema::Long => decode_long(reader), - Schema::TimeMicros => zag_i64(reader).map(Value::TimeMicros), - Schema::TimestampMillis => zag_i64(reader).map(Value::TimestampMillis), - Schema::TimestampMicros => zag_i64(reader).map(Value::TimestampMicros), - Schema::TimestampNanos => zag_i64(reader).map(Value::TimestampNanos), - Schema::LocalTimestampMillis => zag_i64(reader).map(Value::LocalTimestampMillis), - Schema::LocalTimestampMicros => zag_i64(reader).map(Value::LocalTimestampMicros), - Schema::LocalTimestampNanos => zag_i64(reader).map(Value::LocalTimestampNanos), - Schema::Duration => { - let mut buf = [0u8; 12]; - reader.read_exact(&mut buf).map_err(Details::ReadDuration)?; - Ok(Value::Duration(Duration::from(buf))) - } - Schema::Float => { - let mut buf = [0u8; std::mem::size_of::()]; - reader - .read_exact(&mut buf[..]) - .map_err(Details::ReadFloat)?; - Ok(Value::Float(f32::from_le_bytes(buf))) - } - Schema::Double => { - let mut buf = [0u8; std::mem::size_of::()]; - reader - .read_exact(&mut buf[..]) - .map_err(Details::ReadDouble)?; - Ok(Value::Double(f64::from_le_bytes(buf))) - } - Schema::Bytes => { - let len = decode_len(reader)?; - let mut buf = vec![0u8; len]; - reader.read_exact(&mut buf).map_err(Details::ReadBytes)?; - Ok(Value::Bytes(buf)) - } - Schema::String => { - let len = decode_len(reader)?; - let mut buf = vec![0u8; len]; - match reader.read_exact(&mut buf) { - Ok(_) => Ok(Value::String( - String::from_utf8(buf).map_err(Details::ConvertToUtf8)?, - )), - Err(io_err) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - Ok(Value::Null) - } else { - Err(Details::ReadString(io_err).into()) - } - } - } - } 
- Schema::Fixed(FixedSchema { size, .. }) => { - let mut buf = vec![0u8; size]; - reader - .read_exact(&mut buf) - .map_err(|e| Details::ReadFixed(e, size))?; - Ok(Value::Fixed(size, buf)) - } - Schema::Array(ref inner) => { - let mut items = Vec::new(); - - loop { - let len = decode_seq_len(reader)?; - if len == 0 { - break; - } - - items.reserve(len); - for _ in 0..len { - items.push(decode_internal( - &inner.items, - names, - enclosing_namespace, - reader, - )?); - } - } - - Ok(Value::Array(items)) - } - Schema::Map(ref inner) => { - let mut items = HashMap::new(); - - loop { - let len = decode_seq_len(reader)?; - if len == 0 { - break; - } - - items.reserve(len); - for _ in 0..len { - match decode_internal(&Schema::String, names, enclosing_namespace, reader)? { - Value::String(key) => { - let value = - decode_internal(&inner.types, names, enclosing_namespace, reader)?; - items.insert(key, value); - } - value => return Err(Details::MapKeyType(value.into()).into()), - } - } - } - - Ok(Value::Map(items)) - } - Schema::Union(ref inner) => match zag_i64(reader).map_err(Error::into_details) { - Ok(index) => { - let variants = inner.variants(); - let variant = variants - .get(usize::try_from(index).map_err(|e| Details::ConvertI64ToUsize(e, index))?) - .ok_or(Details::GetUnionVariant { - index, - num_variants: variants.len(), - })?; - let value = decode_internal(variant, names, enclosing_namespace, reader)?; - Ok(Value::Union(index as u32, Box::new(value))) - } - Err(Details::ReadVariableIntegerBytes(io_err)) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - Ok(Value::Union(0, Box::new(Value::Null))) - } else { - Err(Details::ReadVariableIntegerBytes(io_err).into()) - } - } - Err(io_err) => Err(Error::new(io_err)), - }, - Schema::Record(RecordSchema { - ref name, - ref fields, - .. - }) => { - let fully_qualified_name = name.fully_qualified_name(enclosing_namespace); - // Benchmarks indicate ~10% improvement using this method. 
- let mut items = Vec::with_capacity(fields.len()); - for field in fields { - // TODO: This clone is also expensive. See if we can do away with it... - items.push(( - field.name.clone(), - decode_internal( - &field.schema, - names, - &fully_qualified_name.namespace, - reader, - )?, - )); - } - Ok(Value::Record(items)) - } - Schema::Enum(EnumSchema { ref symbols, .. }) => { - Ok(if let Value::Int(raw_index) = decode_int(reader)? { - let index = usize::try_from(raw_index) - .map_err(|e| Details::ConvertI32ToUsize(e, raw_index))?; - if (0..symbols.len()).contains(&index) { - let symbol = symbols[index].clone(); - Value::Enum(raw_index as u32, symbol) - } else { - return Err(Details::GetEnumValue { - index, - nsymbols: symbols.len(), - } - .into()); - } - } else { - return Err(Details::GetEnumUnknownIndexValue.into()); - }) - } - Schema::Ref { ref name } => { - let fully_qualified_name = name.fully_qualified_name(enclosing_namespace); - if let Some(resolved) = names.get(&fully_qualified_name) { - decode_internal( - resolved.borrow(), - names, - &fully_qualified_name.namespace, - reader, - ) - } else { - Err(Details::SchemaResolutionError(fully_qualified_name).into()) - } - } - } -} - -#[cfg(test)] -#[allow(clippy::expect_fun_call)] -mod tests { - use crate::{ - Decimal, - decode::decode, - encode::{encode, tests::success}, - schema::{DecimalSchema, FixedSchema, Schema}, - types::{ - Value, - Value::{Array, Int, Map}, - }, - }; - use apache_avro_test_helper::TestResult; - use pretty_assertions::assert_eq; - use std::collections::HashMap; - use uuid::Uuid; - - #[test] - fn test_decode_array_without_size() -> TestResult { - let mut input: &[u8] = &[6, 2, 4, 6, 0]; - let result = decode(&Schema::array(Schema::Int), &mut input); - assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result?); - - Ok(()) - } - - #[test] - fn test_decode_array_with_size() -> TestResult { - let mut input: &[u8] = &[5, 6, 2, 4, 6, 0]; - let result = decode(&Schema::array(Schema::Int), &mut input); - 
assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result?); - - Ok(()) - } - - #[test] - fn test_decode_map_without_size() -> TestResult { - let mut input: &[u8] = &[0x02, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; - let result = decode(&Schema::map(Schema::Int), &mut input); - let mut expected = HashMap::new(); - expected.insert(String::from("test"), Int(1)); - assert_eq!(Map(expected), result?); - - Ok(()) - } - - #[test] - fn test_decode_map_with_size() -> TestResult { - let mut input: &[u8] = &[0x01, 0x0C, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; - let result = decode(&Schema::map(Schema::Int), &mut input); - let mut expected = HashMap::new(); - expected.insert(String::from("test"), Int(1)); - assert_eq!(Map(expected), result?); - - Ok(()) - } - - #[test] - fn test_negative_decimal_value() -> TestResult { - use crate::{encode::encode, schema::Name}; - use num_bigint::ToBigInt; - let inner = Box::new(Schema::Fixed( - FixedSchema::builder() - .name(Name::new("decimal")?) - .size(2) - .build(), - )); - let schema = Schema::Decimal(DecimalSchema { - inner, - precision: 4, - scale: 2, - }); - let bigint = (-423).to_bigint().unwrap(); - let value = Value::Decimal(Decimal::from(bigint.to_signed_bytes_be())); - - let mut buffer = Vec::new(); - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - - let mut bytes = &buffer[..]; - let result = decode(&schema, &mut bytes)?; - assert_eq!(result, value); - - Ok(()) - } - - #[test] - fn test_decode_decimal_with_bigger_than_necessary_size() -> TestResult { - use crate::{encode::encode, schema::Name}; - use num_bigint::ToBigInt; - let inner = Box::new(Schema::Fixed(FixedSchema { - size: 13, - name: Name::new("decimal")?, - aliases: None, - doc: None, - default: None, - attributes: Default::default(), - })); - let schema = Schema::Decimal(DecimalSchema { - inner, - precision: 4, - scale: 2, - }); - let value = Value::Decimal(Decimal::from( - ((-423).to_bigint().unwrap()).to_signed_bytes_be(), - )); - let mut 
buffer = Vec::::new(); - - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - let mut bytes: &[u8] = &buffer[..]; - let result = decode(&schema, &mut bytes)?; - assert_eq!(result, value); - - Ok(()) - } - - #[test] - fn test_avro_3448_recursive_definition_decode_union() -> TestResult { - // if encoding fails in this test check the corresponding test in encode - let schema = Schema::parse_str( - r#" - { - "type":"record", - "name":"TestStruct", - "fields": [ - { - "name":"a", - "type":[ "null", { - "type":"record", - "name": "Inner", - "fields": [ { - "name":"z", - "type":"int" - }] - }] - }, - { - "name":"b", - "type":"Inner" - } - ] - }"#, - )?; - - let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); - let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); - let outer_value1 = Value::Record(vec![ - ("a".into(), Value::Union(1, Box::new(inner_value1))), - ("b".into(), inner_value2.clone()), - ]); - let mut buf = Vec::new(); - encode(&outer_value1, &schema, &mut buf).expect(&success(&outer_value1, &schema)); - assert!(!buf.is_empty()); - let mut bytes = &buf[..]; - assert_eq!( - outer_value1, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - let outer_value2 = Value::Record(vec![ - ("a".into(), Value::Union(0, Box::new(Value::Null))), - ("b".into(), inner_value2), - ]); - encode(&outer_value2, &schema, &mut buf).expect(&success(&outer_value2, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_value2, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_recursive_definition_decode_array() -> TestResult { - let schema = Schema::parse_str( - r#" - { - "type":"record", - "name":"TestStruct", - "fields": [ - { - "name":"a", - "type":{ - "type":"array", - 
"items": { - "type":"record", - "name": "Inner", - "fields": [ { - "name":"z", - "type":"int" - }] - } - } - }, - { - "name":"b", - "type": "Inner" - } - ] - }"#, - )?; - - let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); - let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); - let outer_value = Value::Record(vec![ - ("a".into(), Value::Array(vec![inner_value1])), - ("b".into(), inner_value2), - ]); - let mut buf = Vec::new(); - encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_value, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_recursive_definition_decode_map() -> TestResult { - let schema = Schema::parse_str( - r#" - { - "type":"record", - "name":"TestStruct", - "fields": [ - { - "name":"a", - "type":{ - "type":"map", - "values": { - "type":"record", - "name": "Inner", - "fields": [ { - "name":"z", - "type":"int" - }] - } - } - }, - { - "name":"b", - "type": "Inner" - } - ] - }"#, - )?; - - let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); - let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); - let outer_value = Value::Record(vec![ - ( - "a".into(), - Value::Map(vec![("akey".into(), inner_value1)].into_iter().collect()), - ), - ("b".into(), inner_value2), - ]); - let mut buf = Vec::new(); - encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_value, - decode(&schema, &mut bytes).expect(&format!( - "Failed to decode using recursive definitions with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_proper_multi_level_decoding_middle_namespace() -> TestResult { - // if encoding fails in this test check the corresponding test in encode - let schema = r#" - { - 
"name": "record_name", - "namespace": "space", - "type": "record", - "fields": [ - { - "name": "outer_field_1", - "type": [ - "null", - { - "type": "record", - "name": "middle_record_name", - "namespace":"middle_namespace", - "fields":[ - { - "name":"middle_field_1", - "type":[ - "null", - { - "type":"record", - "name":"inner_record_name", - "fields":[ - { - "name":"inner_field_1", - "type":"double" - } - ] - } - ] - } - ] - } - ] - }, - { - "name": "outer_field_2", - "type" : "middle_namespace.inner_record_name" - } - ] - } - "#; - let schema = Schema::parse_str(schema)?; - let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]); - let middle_record_variation_1 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - )]); - let middle_record_variation_2 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(1, Box::new(inner_record.clone())), - )]); - let outer_record_variation_1 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_2 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_1)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_3 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_2)), - ), - ("outer_field_2".into(), inner_record), - ]); - - let mut buf = Vec::new(); - encode(&outer_record_variation_1, &schema, &mut buf) - .expect(&success(&outer_record_variation_1, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_1, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_2, &schema, &mut buf) - 
.expect(&success(&outer_record_variation_2, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_2, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_3, &schema, &mut buf) - .expect(&success(&outer_record_variation_3, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_3, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn test_avro_3448_proper_multi_level_decoding_inner_namespace() -> TestResult { - // if encoding fails in this test check the corresponding test in encode - let schema = r#" - { - "name": "record_name", - "namespace": "space", - "type": "record", - "fields": [ - { - "name": "outer_field_1", - "type": [ - "null", - { - "type": "record", - "name": "middle_record_name", - "namespace":"middle_namespace", - "fields":[ - { - "name":"middle_field_1", - "type":[ - "null", - { - "type":"record", - "name":"inner_record_name", - "namespace":"inner_namespace", - "fields":[ - { - "name":"inner_field_1", - "type":"double" - } - ] - } - ] - } - ] - } - ] - }, - { - "name": "outer_field_2", - "type" : "inner_namespace.inner_record_name" - } - ] - } - "#; - let schema = Schema::parse_str(schema)?; - let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]); - let middle_record_variation_1 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - )]); - let middle_record_variation_2 = Value::Record(vec![( - "middle_field_1".into(), - Value::Union(1, Box::new(inner_record.clone())), - )]); - let outer_record_variation_1 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(0, Box::new(Value::Null)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - 
let outer_record_variation_2 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_1)), - ), - ("outer_field_2".into(), inner_record.clone()), - ]); - let outer_record_variation_3 = Value::Record(vec![ - ( - "outer_field_1".into(), - Value::Union(1, Box::new(middle_record_variation_2)), - ), - ("outer_field_2".into(), inner_record), - ]); - - let mut buf = Vec::new(); - encode(&outer_record_variation_1, &schema, &mut buf) - .expect(&success(&outer_record_variation_1, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_1, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_2, &schema, &mut buf) - .expect(&success(&outer_record_variation_2, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_2, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - let mut buf = Vec::new(); - encode(&outer_record_variation_3, &schema, &mut buf) - .expect(&success(&outer_record_variation_3, &schema)); - let mut bytes = &buf[..]; - assert_eq!( - outer_record_variation_3, - decode(&schema, &mut bytes).expect(&format!( - "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", - &schema - )) - ); - - Ok(()) - } - - #[test] - fn avro_3926_encode_decode_uuid_to_string() -> TestResult { - use crate::encode::encode; - - let schema = Schema::String; - let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?); - - let mut buffer = Vec::new(); - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - - let result = decode(&Schema::Uuid, &mut &buffer[..])?; - assert_eq!(result, value); - - Ok(()) - } - - #[test] - fn avro_3926_encode_decode_uuid_to_fixed() -> TestResult { - use 
crate::encode::encode; - - let schema = Schema::Fixed(FixedSchema { - size: 16, - name: "uuid".into(), - aliases: None, - doc: None, - default: None, - attributes: Default::default(), - }); - let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?); - - let mut buffer = Vec::new(); - encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - - let result = decode(&Schema::Uuid, &mut &buffer[..])?; - assert_eq!(result, value); - - Ok(()) - } -} diff --git a/avro/src/error.rs b/avro/src/error.rs index 95aeb2b9..2163d050 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -17,6 +17,7 @@ use crate::{ schema::{Name, Schema, SchemaKind, UnionSchema}, + state_machines::reading::error::ValueFromTapeError, types::{Value, ValueKind}, }; use std::{error::Error as _, fmt}; @@ -56,6 +57,12 @@ impl From
for Error { } } +impl From for Error { + fn from(value: ValueFromTapeError) -> Self { + Self::new(value.into()) + } +} + impl serde::ser::Error for Error { fn custom(msg: T) -> Self { Self::new(
::custom(msg)) @@ -576,6 +583,9 @@ pub enum Details { #[error("Cannot convert a slice to Uuid: {0}")] UuidFromSlice(#[source] uuid::Error), + + #[error(transparent)] + ValueFromTapeError(#[from] ValueFromTapeError), } #[derive(thiserror::Error, PartialEq)] diff --git a/avro/src/lib.rs b/avro/src/lib.rs index 95dffa7d..baea6e10 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -947,7 +947,6 @@ mod bytes; mod codec; mod de; mod decimal; -mod decode; mod duration; mod encode; mod reader; @@ -961,6 +960,7 @@ pub mod rabin; pub mod schema; pub mod schema_compatibility; pub mod schema_equality; +pub mod state_machines; pub mod types; pub mod util; pub mod validator; @@ -1038,6 +1038,12 @@ pub fn set_serde_human_readable(human_readable: bool) -> bool { util::set_serde_human_readable(human_readable) } +/// Async versions of the types and functions. +pub mod not_sync { + #[doc(inline)] + pub use crate::reader::async_reader::*; +} + #[cfg(test)] mod tests { use crate::{ diff --git a/avro/src/reader.rs b/avro/src/reader.rs index ec7412cc..5590ddd3 100644 --- a/avro/src/reader.rs +++ b/avro/src/reader.rs @@ -16,504 +16,49 @@ // under the License. //! Logic handling reading from Avro format at user level. 
+ +pub use crate::state_machines::reading::sync::{ + Reader, from_avro_datum, from_avro_datum_reader_schemata, from_avro_datum_schemata, +}; use crate::{ - AvroResult, Codec, Error, - decode::{decode, decode_internal}, + AvroResult, error::Details, from_value, headers::{HeaderBuilder, RabinFingerprintHeader}, - schema::{ - AvroSchema, Names, ResolvedOwnedSchema, ResolvedSchema, Schema, resolve_names, - resolve_names_with_schemata, - }, + schema::{AvroSchema, ResolvedOwnedSchema, Schema}, types::Value, - util, }; -use log::warn; +use futures::AsyncRead; use serde::de::DeserializeOwned; -use serde_json::from_slice; -use std::{ - collections::HashMap, - io::{ErrorKind, Read}, - marker::PhantomData, - str::FromStr, -}; +use std::{io::Read, marker::PhantomData}; -/// Internal Block reader. -#[derive(Debug, Clone)] -struct Block<'r, R> { - reader: R, - /// Internal buffering to reduce allocation. - buf: Vec, - buf_idx: usize, - /// Number of elements expected to exist within this block. - message_count: usize, - marker: [u8; 16], - codec: Codec, - writer_schema: Schema, - schemata: Vec<&'r Schema>, - user_metadata: HashMap>, - names_refs: Names, +pub mod async_reader { + #[doc(inline)] + pub use crate::state_machines::reading::async_impl::{ + Reader, from_avro_datum, from_avro_datum_reader_schemata, from_avro_datum_schemata, + }; } -impl<'r, R: Read> Block<'r, R> { - fn new(reader: R, schemata: Vec<&'r Schema>) -> AvroResult> { - let mut block = Block { - reader, - codec: Codec::Null, - writer_schema: Schema::Null, - schemata, - buf: vec![], - buf_idx: 0, - message_count: 0, - marker: [0; 16], - user_metadata: Default::default(), - names_refs: Default::default(), - }; - - block.read_header()?; - Ok(block) - } - - /// Try to read the header and to set the writer `Schema`, the `Codec` and the marker based on - /// its content. 
- fn read_header(&mut self) -> AvroResult<()> { - let mut buf = [0u8; 4]; - self.reader - .read_exact(&mut buf) - .map_err(Details::ReadHeader)?; - - if buf != [b'O', b'b', b'j', 1u8] { - return Err(Details::HeaderMagic.into()); - } - - let meta_schema = Schema::map(Schema::Bytes); - match decode(&meta_schema, &mut self.reader)? { - Value::Map(metadata) => { - self.read_writer_schema(&metadata)?; - self.codec = read_codec(&metadata)?; - - for (key, value) in metadata { - if key == "avro.schema" - || key == "avro.codec" - || key == "avro.codec.compression_level" - { - // already processed - } else if key.starts_with("avro.") { - warn!("Ignoring unknown metadata key: {key}"); - } else { - self.read_user_metadata(key, value); - } - } - } - _ => { - return Err(Details::GetHeaderMetadata.into()); - } - } - - self.reader - .read_exact(&mut self.marker) - .map_err(|e| Details::ReadMarker(e).into()) - } - - fn fill_buf(&mut self, n: usize) -> AvroResult<()> { - // The buffer needs to contain exactly `n` elements, otherwise codecs will potentially read - // invalid bytes. - // - // The are two cases to handle here: - // - // 1. `n > self.buf.len()`: - // In this case we call `Vec::resize`, which guarantees that `self.buf.len() == n`. - // 2. `n < self.buf.len()`: - // We need to resize to ensure that the buffer len is safe to read `n` elements. - // - // TODO: Figure out a way to avoid having to truncate for the second case. - self.buf.resize(util::safe_len(n)?, 0); - self.reader - .read_exact(&mut self.buf) - .map_err(Details::ReadIntoBuf)?; - self.buf_idx = 0; - Ok(()) - } - - /// Try to read a data block, also performing schema resolution for the objects contained in - /// the block. The objects are stored in an internal buffer to the `Reader`. 
- fn read_block_next(&mut self) -> AvroResult<()> { - assert!(self.is_empty(), "Expected self to be empty!"); - match util::read_long(&mut self.reader).map_err(Error::into_details) { - Ok(block_len) => { - self.message_count = block_len as usize; - let block_bytes = util::read_long(&mut self.reader)?; - self.fill_buf(block_bytes as usize)?; - let mut marker = [0u8; 16]; - self.reader - .read_exact(&mut marker) - .map_err(Details::ReadBlockMarker)?; - - if marker != self.marker { - return Err(Details::GetBlockMarker.into()); - } - - // NOTE (JAB): This doesn't fit this Reader pattern very well. - // `self.buf` is a growable buffer that is reused as the reader is iterated. - // For non `Codec::Null` variants, `decompress` will allocate a new `Vec` - // and replace `buf` with the new one, instead of reusing the same buffer. - // We can address this by using some "limited read" type to decode directly - // into the buffer. But this is fine, for now. - self.codec.decompress(&mut self.buf) - } - Err(Details::ReadVariableIntegerBytes(io_err)) => { - if let ErrorKind::UnexpectedEof = io_err.kind() { - // to not return any error in case we only finished to read cleanly from the stream - Ok(()) - } else { - Err(Details::ReadVariableIntegerBytes(io_err).into()) - } - } - Err(e) => Err(Error::new(e)), - } - } - - fn len(&self) -> usize { - self.message_count - } - - fn is_empty(&self) -> bool { - self.len() == 0 - } - - fn read_next(&mut self, read_schema: Option<&Schema>) -> AvroResult> { - if self.is_empty() { - self.read_block_next()?; - if self.is_empty() { - return Ok(None); - } - } - - let mut block_bytes = &self.buf[self.buf_idx..]; - let b_original = block_bytes.len(); - - let item = decode_internal( - &self.writer_schema, - &self.names_refs, - &None, - &mut block_bytes, - )?; - let item = match read_schema { - Some(schema) => item.resolve(schema)?, - None => item, - }; - - if b_original != 0 && b_original == block_bytes.len() { - // from_avro_datum did not consume any 
bytes, so return an error to avoid an infinite loop - return Err(Details::ReadBlock.into()); - } - self.buf_idx += b_original - block_bytes.len(); - self.message_count -= 1; - Ok(Some(item)) - } - - fn read_writer_schema(&mut self, metadata: &HashMap) -> AvroResult<()> { - let json: serde_json::Value = metadata - .get("avro.schema") - .and_then(|bytes| { - if let Value::Bytes(ref bytes) = *bytes { - from_slice(bytes.as_ref()).ok() - } else { - None - } - }) - .ok_or(Details::GetAvroSchemaFromMap)?; - if !self.schemata.is_empty() { - let rs = ResolvedSchema::try_from(self.schemata.clone())?; - let names: Names = rs - .get_names() - .iter() - .map(|(name, schema)| (name.clone(), (*schema).clone())) - .collect(); - self.writer_schema = Schema::parse_with_names(&json, names)?; - resolve_names_with_schemata(&self.schemata, &mut self.names_refs, &None)?; - } else { - self.writer_schema = Schema::parse(&json)?; - resolve_names(&self.writer_schema, &mut self.names_refs, &None)?; - } - Ok(()) - } - - fn read_user_metadata(&mut self, key: String, value: Value) { - match value { - Value::Bytes(ref vec) => { - self.user_metadata.insert(key, vec.clone()); - } - wrong => { - warn!("User metadata values must be Value::Bytes, found {wrong:?}"); - } - } - } -} - -fn read_codec(metadata: &HashMap) -> AvroResult { - let result = metadata - .get("avro.codec") - .map(|codec| { - if let Value::Bytes(ref bytes) = *codec { - match std::str::from_utf8(bytes.as_ref()) { - Ok(utf8) => Ok(utf8), - Err(utf8_error) => Err(Details::ConvertToUtf8Error(utf8_error).into()), - } - } else { - Err(Details::BadCodecMetadata.into()) - } - }) - .map(|codec_res| match codec_res { - Ok(codec) => match Codec::from_str(codec) { - Ok(codec) => match codec { - #[cfg(feature = "bzip")] - Codec::Bzip2(_) => { - use crate::Bzip2Settings; - if let Some(Value::Bytes(bytes)) = - metadata.get("avro.codec.compression_level") - { - Ok(Codec::Bzip2(Bzip2Settings::new(bytes[0]))) - } else { - Ok(codec) - } - } - 
#[cfg(feature = "xz")] - Codec::Xz(_) => { - use crate::XzSettings; - if let Some(Value::Bytes(bytes)) = - metadata.get("avro.codec.compression_level") - { - Ok(Codec::Xz(XzSettings::new(bytes[0]))) - } else { - Ok(codec) - } - } - #[cfg(feature = "zstandard")] - Codec::Zstandard(_) => { - use crate::ZstandardSettings; - if let Some(Value::Bytes(bytes)) = - metadata.get("avro.codec.compression_level") - { - Ok(Codec::Zstandard(ZstandardSettings::new(bytes[0]))) - } else { - Ok(codec) - } - } - _ => Ok(codec), - }, - Err(_) => Err(Details::CodecNotSupported(codec.to_owned()).into()), - }, - Err(err) => Err(err), - }); - - result.unwrap_or(Ok(Codec::Null)) -} - -/// Main interface for reading Avro formatted values. -/// -/// To be used as an iterator: +/// Reader for Avro objects created using the [single-object encoding]. /// -/// ```no_run -/// # use apache_avro::Reader; -/// # use std::io::Cursor; -/// # let input = Cursor::new(Vec::::new()); -/// for value in Reader::new(input).unwrap() { -/// match value { -/// Ok(v) => println!("{:?}", v), -/// Err(e) => println!("Error: {}", e), -/// }; -/// } -/// ``` -pub struct Reader<'a, R> { - block: Block<'a, R>, - reader_schema: Option<&'a Schema>, - errored: bool, - should_resolve_schema: bool, -} - -impl<'a, R: Read> Reader<'a, R> { - /// Creates a `Reader` given something implementing the `io::Read` trait to read from. - /// No reader `Schema` will be set. - /// - /// **NOTE** The avro header is going to be read automatically upon creation of the `Reader`. - pub fn new(reader: R) -> AvroResult> { - let block = Block::new(reader, vec![])?; - let reader = Reader { - block, - reader_schema: None, - errored: false, - should_resolve_schema: false, - }; - Ok(reader) - } - - /// Creates a `Reader` given a reader `Schema` and something implementing the `io::Read` trait - /// to read from. - /// - /// **NOTE** The avro header is going to be read automatically upon creation of the `Reader`. 
- pub fn with_schema(schema: &'a Schema, reader: R) -> AvroResult> { - let block = Block::new(reader, vec![schema])?; - let mut reader = Reader { - block, - reader_schema: Some(schema), - errored: false, - should_resolve_schema: false, - }; - // Check if the reader and writer schemas disagree. - reader.should_resolve_schema = reader.writer_schema() != schema; - Ok(reader) - } - - /// Creates a `Reader` given a reader `Schema` and something implementing the `io::Read` trait - /// to read from. - /// - /// **NOTE** The avro header is going to be read automatically upon creation of the `Reader`. - pub fn with_schemata( - schema: &'a Schema, - schemata: Vec<&'a Schema>, - reader: R, - ) -> AvroResult> { - let block = Block::new(reader, schemata)?; - let mut reader = Reader { - block, - reader_schema: Some(schema), - errored: false, - should_resolve_schema: false, - }; - // Check if the reader and writer schemas disagree. - reader.should_resolve_schema = reader.writer_schema() != schema; - Ok(reader) - } - - /// Get a reference to the writer `Schema`. - #[inline] - pub fn writer_schema(&self) -> &Schema { - &self.block.writer_schema - } - - /// Get a reference to the optional reader `Schema`. 
- #[inline] - pub fn reader_schema(&self) -> Option<&Schema> { - self.reader_schema - } - - /// Get a reference to the user metadata - #[inline] - pub fn user_metadata(&self) -> &HashMap> { - &self.block.user_metadata - } - - #[inline] - fn read_next(&mut self) -> AvroResult> { - let read_schema = if self.should_resolve_schema { - self.reader_schema - } else { - None - }; - - self.block.read_next(read_schema) - } -} - -impl Iterator for Reader<'_, R> { - type Item = AvroResult; - - fn next(&mut self) -> Option { - // to prevent keep on reading after the first error occurs - if self.errored { - return None; - }; - match self.read_next() { - Ok(opt) => opt.map(Ok), - Err(e) => { - self.errored = true; - Some(Err(e)) - } - } - } -} - -/// Decode a `Value` encoded in Avro format given its `Schema` and anything implementing `io::Read` -/// to read from. -/// -/// In case a reader `Schema` is provided, schema resolution will also be performed. -/// -/// **NOTE** This function has a quite small niche of usage and does NOT take care of reading the -/// header and consecutive data blocks; use [`Reader`](struct.Reader.html) if you don't know what -/// you are doing, instead. -pub fn from_avro_datum( - writer_schema: &Schema, - reader: &mut R, - reader_schema: Option<&Schema>, -) -> AvroResult { - let value = decode(writer_schema, reader)?; - match reader_schema { - Some(schema) => value.resolve(schema), - None => Ok(value), - } -} - -/// Decode a `Value` encoded in Avro format given the provided `Schema` and anything implementing `io::Read` -/// to read from. -/// If the writer schema is incomplete, i.e. contains `Schema::Ref`s then it will use the provided -/// schemata to resolve any dependencies. -/// -/// In case a reader `Schema` is provided, schema resolution will also be performed. 
-pub fn from_avro_datum_schemata( - writer_schema: &Schema, - writer_schemata: Vec<&Schema>, - reader: &mut R, - reader_schema: Option<&Schema>, -) -> AvroResult { - from_avro_datum_reader_schemata( - writer_schema, - writer_schemata, - reader, - reader_schema, - Vec::with_capacity(0), - ) -} - -/// Decode a `Value` encoded in Avro format given the provided `Schema` and anything implementing `io::Read` -/// to read from. -/// If the writer schema is incomplete, i.e. contains `Schema::Ref`s then it will use the provided -/// schemata to resolve any dependencies. -/// -/// In case a reader `Schema` is provided, schema resolution will also be performed. -pub fn from_avro_datum_reader_schemata( - writer_schema: &Schema, - writer_schemata: Vec<&Schema>, - reader: &mut R, - reader_schema: Option<&Schema>, - reader_schemata: Vec<&Schema>, -) -> AvroResult { - let rs = ResolvedSchema::try_from(writer_schemata)?; - let value = decode_internal(writer_schema, rs.get_names(), &None, reader)?; - match reader_schema { - Some(schema) => { - if reader_schemata.is_empty() { - value.resolve(schema) - } else { - value.resolve_schemata(schema, reader_schemata) - } - } - None => Ok(value), - } -} - +/// [single-object encoding]: https://avro.apache.org/docs/++version++/specification/#single-object-encoding pub struct GenericSingleObjectReader { write_schema: ResolvedOwnedSchema, expected_header: Vec, } impl GenericSingleObjectReader { + /// Create a reader for the given schema. + /// + /// This will expect the input to use the [`RabinFingerprintHeader`]. pub fn new(schema: Schema) -> AvroResult { let header_builder = RabinFingerprintHeader::from_schema(&schema); Self::new_with_header_builder(schema, header_builder) } + /// Create a reader for the given schema with a custom fingerprint. + /// + /// See [`HeaderBuilder`] for details on how to implement a custom fingerprint. 
pub fn new_with_header_builder( schema: Schema, header_builder: HB, @@ -525,17 +70,36 @@ impl GenericSingleObjectReader { }) } + /// Read a [`Value`] from the reader. pub fn read_value(&self, reader: &mut R) -> AvroResult { let mut header = vec![0; self.expected_header.len()]; match reader.read_exact(&mut header) { Ok(_) => { if self.expected_header == header { - decode_internal( - self.write_schema.get_root_schema(), - self.write_schema.get_names(), - &None, - reader, + from_avro_datum(self.write_schema.get_root_schema(), reader, None) + } else { + Err( + Details::SingleObjectHeaderMismatch(self.expected_header.clone(), header) + .into(), ) + } + } + Err(io_error) => Err(Details::ReadHeader(io_error).into()), + } + } + + pub async fn read_value_async( + &self, + reader: &mut R, + ) -> AvroResult { + use futures::AsyncReadExt as _; + + let mut header = vec![0; self.expected_header.len()]; + match reader.read_exact(&mut header).await { + Ok(_) => { + if self.expected_header == header { + async_reader::from_avro_datum(self.write_schema.get_root_schema(), reader, None) + .await } else { Err( Details::SingleObjectHeaderMismatch(self.expected_header.clone(), header) @@ -548,6 +112,9 @@ impl GenericSingleObjectReader { } } +/// Reader for Avro objects created using the [single-object encoding] deserializing directly to `T`. +/// +/// [single-object encoding]: https://avro.apache.org/docs/++version++/specification/#single-object-encoding pub struct SpecificSingleObjectReader where T: AvroSchema, @@ -560,6 +127,7 @@ impl SpecificSingleObjectReader where T: AvroSchema, { + /// Create the reader from the schema associated with `T`. pub fn new() -> AvroResult> { Ok(SpecificSingleObjectReader { inner: GenericSingleObjectReader::new(T::get_schema())?, @@ -572,21 +140,37 @@ impl SpecificSingleObjectReader where T: AvroSchema + From, { + /// Read a `T` from the reader. 
pub fn read_from_value(&self, reader: &mut R) -> AvroResult { self.inner.read_value(reader).map(|v| v.into()) } + + /// Read a `T` from the reader. + pub async fn read_from_value_async( + &self, + reader: &mut R, + ) -> AvroResult { + self.inner.read_value_async(reader).await.map(|v| v.into()) + } } impl SpecificSingleObjectReader where T: AvroSchema + DeserializeOwned, { + /// Read a `T` from the reader. pub fn read(&self, reader: &mut R) -> AvroResult { from_value::(&self.inner.read_value(reader)?) } + + pub async fn read_async(&self, reader: &mut R) -> AvroResult { + from_value::(&self.inner.read_value_async(reader).await?) + } } -/// Reads the marker bytes from Avro bytes generated earlier by a `Writer` +/// Reads the marker bytes from Avro bytes generated earlier by a [`Writer`]. +/// +/// [`Writer`]: crate::Writer pub fn read_marker(bytes: &[u8]) -> [u8; 16] { assert!( bytes.len() > 16, @@ -600,11 +184,13 @@ pub fn read_marker(bytes: &[u8]) -> [u8; 16] { #[cfg(test)] mod tests { use super::*; - use crate::{encode::encode, headers::GlueSchemaUuidHeader, rabin::Rabin, types::Record}; + use crate::{ + Error, encode::encode, headers::GlueSchemaUuidHeader, rabin::Rabin, types::Record, + }; use apache_avro_test_helper::TestResult; use pretty_assertions::assert_eq; use serde::Deserialize; - use std::io::Cursor; + use std::{collections::HashMap, io::Cursor}; use uuid::Uuid; const SCHEMA: &str = r#" @@ -704,22 +290,26 @@ mod tests { let schema = Schema::parse_str(TEST_RECORD_SCHEMA_3240)?; let mut encoded: &'static [u8] = &[54, 6, 102, 111, 111]; - let expected_record: TestRecord3240 = TestRecord3240 { - a: 27i64, - b: String::from("foo"), - a_nullable_array: None, - a_nullable_string: None, - }; + // The schema used to read is not compatible with what is written + assert!(from_avro_datum(&schema, &mut encoded, None).is_err()); - let avro_datum = from_avro_datum(&schema, &mut encoded, None)?; - let parsed_record: TestRecord3240 = match &avro_datum { - Value::Record(_) 
=> from_value::(&avro_datum)?, - unexpected => { - panic!("could not map avro data to struct, found unexpected: {unexpected:?}") - } - }; + // let avro_datum = from_avro_datum(&schema, &mut encoded, None)?; - assert_eq!(parsed_record, expected_record); + // let expected_record: TestRecord3240 = TestRecord3240 { + // a: 27i64, + // b: String::from("foo"), + // a_nullable_array: None, + // a_nullable_string: None, + // }; + + // let parsed_record: TestRecord3240 = match &avro_datum { + // Value::Record(_) => from_value::(&avro_datum)?, + // unexpected => { + // panic!("could not map avro data to struct, found unexpected: {unexpected:?}") + // } + // }; + // + // assert_eq!(parsed_record, expected_record); Ok(()) } @@ -780,10 +370,13 @@ mod tests { .into_iter() .rev() .collect::>(); - let reader = Reader::with_schema(&schema, &invalid[..])?; - for value in reader { - assert!(value.is_err()); - } + let mut reader = Reader::with_schema(&schema, &invalid[..])?; + + // The block says it contains 2 values, but only contains one. + // The first value is successfully decoded + let _v = reader.next().unwrap().unwrap(); + // The second fails with an unexpected end of file error. 
+ assert!(reader.next().unwrap().is_err()); Ok(()) } @@ -815,10 +408,7 @@ mod tests { let mut writer = Writer::new(&schema, Vec::new())?; let mut user_meta_data: HashMap> = HashMap::new(); - user_meta_data.insert( - "stringKey".to_string(), - "stringValue".to_string().into_bytes(), - ); + user_meta_data.insert("stringKey".to_string(), b"stringValue".to_vec()); user_meta_data.insert("bytesKey".to_string(), b"bytesValue".to_vec()); user_meta_data.insert("vecKey".to_string(), vec![1, 2, 3]); diff --git a/avro/src/state_machines/mod.rs b/avro/src/state_machines/mod.rs new file mode 100644 index 00000000..28157eae --- /dev/null +++ b/avro/src/state_machines/mod.rs @@ -0,0 +1 @@ +pub mod reading; diff --git a/avro/src/state_machines/reading/async_impl.rs b/avro/src/state_machines/reading/async_impl.rs new file mode 100644 index 00000000..c4000f1f --- /dev/null +++ b/avro/src/state_machines/reading/async_impl.rs @@ -0,0 +1,302 @@ +use async_stream::try_stream; +use futures::{AsyncRead, AsyncReadExt, Stream}; +use oval::Buffer; +use serde::Deserialize; +use std::collections::HashMap; + +use crate::{ + AvroResult, Error, Schema, + error::Details, + schema::{resolve_names, resolve_names_with_schemata}, + state_machines::reading::{ + ItemRead, StateMachine, StateMachineControlFlow, + commands::CommandTape, + datum::DatumStateMachine, + deserialize_from_tape, + object_container_file::{ + ObjectContainerFileBodyStateMachine, ObjectContainerFileHeader, + ObjectContainerFileHeaderStateMachine, + }, + value_from_tape, + }, + types::Value, +}; + +// This should probably also be a state machine and be wrapped in sync and async versions. +// But this suffices for the demonstration. +pub struct Reader<'a, R> { + reader_schema: Option<&'a Schema>, + header: ObjectContainerFileHeader, + fsm: Option, + reader: R, + buffer: Buffer, +} + +impl<'a, R: AsyncRead + Unpin> Reader<'a, R> { + /// Creates a [`crate::Reader`] that will use the schema from the file header. 
+ /// + /// No reader [`Schema`] will be set. + /// + /// **NOTE** The Avro header is going to be read automatically upon creation of the [`crate::Reader`]. + pub async fn new(reader: R) -> Result { + Self::new_inner(reader, None, Vec::new()).await + } + + /// Creates a [`crate::Reader`] that will use the given schema for schema resolution. + /// + /// **NOTE** The Avro header is going to be read automatically upon creation of the [`crate::Reader`]. + pub async fn with_schema(schema: &'a Schema, reader: R) -> Result { + Self::new_inner(reader, Some(schema), Vec::new()).await + } + + /// Creates a [`crate::Reader`] that will use the given schema for schema resolution. + /// + /// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be + /// resolved and an error will be returned. + /// + /// Any [`Schema::Ref`] will be resolved using the schemata. + /// + /// **NOTE** The avro header is going to be read automatically upon creation of the [`crate::Reader`]. + pub async fn with_schemata( + schema: &'a Schema, + schemata: Vec<&'a Schema>, + reader: R, + ) -> Result { + Self::new_inner(reader, Some(schema), schemata).await + } + + /// Get a reference to the optional reader [`Schema`]. + /// + /// This will only be set if there was a reader schema provided *and* it differed from the + /// writer schema. + pub fn reader_schema(&self) -> Option<&'a Schema> { + self.reader_schema + } + + /// Get a reference to the user metadata. + pub fn user_metadata(&self) -> &HashMap> { + &self.header.metadata + } + + /// Get a reference to the file header. 
+ pub fn header(&self) -> &ObjectContainerFileHeader { + &self.header + } + + async fn new_inner( + mut reader: R, + reader_schema: Option<&'a Schema>, + schemata: Vec<&'a Schema>, + ) -> Result { + // Read a maximum of 2Kb per read + let mut buffer = Buffer::with_capacity(2 * 1024); + + // Parse the header + let mut fsm = ObjectContainerFileHeaderStateMachine::new(schemata); + let header = loop { + // Fill the buffer + let n = reader + .read(buffer.space()) + .await + .map_err(Details::ReadHeader)?; + if n == 0 { + return Err(Details::ReadHeader(std::io::ErrorKind::UnexpectedEof.into()).into()); + } + buffer.fill(n); + + // Start/continue the state machine + match fsm.parse(&mut buffer)? { + StateMachineControlFlow::NeedMore(new_fsm) => fsm = new_fsm, + StateMachineControlFlow::Done(header) => break header, + } + }; + + let tape = CommandTape::build_from_schema(&header.schema, &header.names)?; + + let reader_schema = if let Some(schema) = reader_schema + && schema != &header.schema + { + Some(schema) + } else { + None + }; + + Ok(Self { + reader_schema, + fsm: Some(ObjectContainerFileBodyStateMachine::new( + tape, + header.sync, + header.codec, + )), + header, + reader, + buffer, + }) + } + + /// Get the next object in the file + async fn next_object(&mut self) -> Option, Error>> { + if let Some(mut fsm) = self.fsm.take() { + loop { + match fsm.parse(&mut self.buffer) { + Ok(StateMachineControlFlow::NeedMore(new_fsm)) => { + fsm = new_fsm; + let n = match self.reader.read(self.buffer.space()).await { + Ok(0) => { + return Some(Err(Details::ReadIntoBuf( + std::io::ErrorKind::UnexpectedEof.into(), + ) + .into())); + } + Ok(n) => n, + Err(e) => return Some(Err(Details::ReadIntoBuf(e).into())), + }; + self.buffer.fill(n); + } + Ok(StateMachineControlFlow::Done(Some((object, fsm)))) => { + self.fsm.replace(fsm); + return Some(Ok(object)); + } + Ok(StateMachineControlFlow::Done(None)) => { + return None; + } + Err(e) => { + return Some(Err(e)); + } + } + } + } + None + 
} + + pub async fn stream_serde<'b, T: Deserialize<'b>>( + &mut self, + ) -> impl Stream> { + assert!( + self.reader_schema.is_none(), + "Reader schema is not supported with Serde!" + ); + try_stream! { + while let Some(object) = self.next_object().await { + let mut tape = object?; + yield deserialize_from_tape(&mut tape, &self.header.schema)?; + } + } + } + + pub async fn stream(&mut self) -> impl Stream> { + try_stream! { + while let Some(object) = self.next_object().await { + let mut tape = object?; + + let value = value_from_tape(&mut tape, &self.header.schema, &self.header.names)?; + let resolved = if let Some(schema) = self.reader_schema { + value.resolve_internal(schema, &self.header.names, &None, &None)? + } else { + value + }; + yield resolved; + } + } + } +} + +/// Decode a raw Avro datum using the provided [`Schema`]. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. +/// +/// **NOTE** This function is very niche and does NOT take care of reading the header and +/// consecutive data blocks. use [`Reader`] if you just want to read an Avro encoded file. +pub async fn from_avro_datum( + writer_schema: &Schema, + reader: &mut R, + reader_schema: Option<&Schema>, +) -> AvroResult { + from_avro_datum_reader_schemata(writer_schema, Vec::new(), reader, reader_schema, Vec::new()) + .await +} + +/// Decode a raw Avro datum using the provided [`Schema`] and schemata. +/// +/// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be +/// resolved and an error will be returned. +/// +/// If the writer schema contains any [`Schema::Ref`] then it will use the provided +/// schemata to resolve any dependencies. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. 
+pub async fn from_avro_datum_schemata( + writer_schema: &Schema, + writer_schemata: Vec<&Schema>, + reader: &mut R, + reader_schema: Option<&Schema>, +) -> AvroResult { + from_avro_datum_reader_schemata( + writer_schema, + writer_schemata, + reader, + reader_schema, + Vec::new(), + ) + .await +} + +/// Decode a raw Avro datum using the provided [`Schema`] and schemata. +/// +/// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be +/// resolved and an error will be returned. +/// +/// If the writer schema contains any [`Schema::Ref`] then it will use the provided +/// schemata to resolve any dependencies. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. +pub async fn from_avro_datum_reader_schemata( + writer_schema: &Schema, + writer_schemata: Vec<&Schema>, + reader: &mut R, + reader_schema: Option<&Schema>, + reader_schemata: Vec<&Schema>, +) -> AvroResult { + let mut names = HashMap::new(); + if writer_schemata.is_empty() { + resolve_names(writer_schema, &mut names, &None)?; + } else { + resolve_names_with_schemata(&writer_schemata, &mut names, &None)?; + } + + let tape = CommandTape::build_from_schema(writer_schema, &names)?; + + // Read a maximum of 2Kb per read + let mut buffer = Buffer::with_capacity(2 * 1024); + let mut fsm = DatumStateMachine::new(tape); + let value = loop { + // Fill the buffer + let n = reader + .read(buffer.space()) + .await + .map_err(Details::ReadIntoBuf)?; + if n == 0 { + return Err(Details::ReadIntoBuf(std::io::ErrorKind::UnexpectedEof.into()).into()); + } + buffer.fill(n); + + match fsm.parse(&mut buffer)? 
{ + StateMachineControlFlow::NeedMore(new_fsm) => { + fsm = new_fsm; + } + StateMachineControlFlow::Done(mut tape) => { + break value_from_tape(&mut tape, writer_schema, &names)?; + } + } + }; + match reader_schema { + Some(schema) => { + if reader_schemata.is_empty() { + value.resolve(schema) + } else { + value.resolve_schemata(schema, reader_schemata) + } + } + None => Ok(value), + } +} diff --git a/avro/src/state_machines/reading/block.rs b/avro/src/state_machines/reading/block.rs new file mode 100644 index 00000000..cf03f6f2 --- /dev/null +++ b/avro/src/state_machines/reading/block.rs @@ -0,0 +1,110 @@ +use oval::Buffer; + +use crate::{ + Error, + error::Details, + state_machines::reading::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, datum::DatumStateMachine, + decode_zigzag_buffer, + }, +}; + +/// Are we currently parsing an object or just finished/reading a block header +enum TapeOrFsm { + Tape(Vec), + Fsm(DatumStateMachine), +} + +pub struct BlockStateMachine { + command_tape: CommandTape, + tape_or_fsm: TapeOrFsm, + left_in_current_block: usize, + need_to_read_block_byte_size: bool, +} + +impl BlockStateMachine { + pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { + Self { + // This clone is *cheap* + command_tape, + tape_or_fsm: TapeOrFsm::Tape(tape), + left_in_current_block: 0, + need_to_read_block_byte_size: false, + } + } +} + +impl StateMachine for BlockStateMachine { + type Output = Vec; + fn parse( + mut self, + buffer: &mut Buffer, + ) -> Result, Error> { + loop { + match self.tape_or_fsm { + TapeOrFsm::Tape(mut tape) => { + // If we finished the last block (or are newly created) read the block info + if self.left_in_current_block == 0 { + let Some(block) = decode_zigzag_buffer(buffer)? 
else { + // Not enough data left in the buffer + self.tape_or_fsm = TapeOrFsm::Tape(tape); + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + + // Need to read the block byte size when block is negative + self.need_to_read_block_byte_size = block.is_negative(); + + // We do the rest with the absolute block size + let abs_block = usize::try_from(block.unsigned_abs()) + .map_err(|e| Details::ConvertU64ToUsize(e, block.unsigned_abs()))?; + self.left_in_current_block = abs_block; + tape.push(ItemRead::Block(abs_block)); + + // Done parsing the blocks + if abs_block == 0 { + return Ok(StateMachineControlFlow::Done(tape)); + } + } + + // If the block length was negative we need to read the block size + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + self.tape_or_fsm = TapeOrFsm::Tape(tape); + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + + // Make sure the value is sane + // TODO: Maybe use safe_len here? + let _ = usize::try_from(block) + .map_err(|e| Details::ConvertI64ToUsize(e, block))?; + + // This is not necessary, as it will be overwritten before being read again + // but it does show the intent more clearly + self.need_to_read_block_byte_size = false; + } + + // We've either finished reading the block header or the last object was read and + // left_in_current_block is not zero + self.tape_or_fsm = TapeOrFsm::Fsm(DatumStateMachine::new_with_tape( + self.command_tape.clone(), + tape, + )) + } + TapeOrFsm::Fsm(fsm) => { + // (Continue) reading the object + match fsm.parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(fsm); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(tape) => { + self.tape_or_fsm = TapeOrFsm::Tape(tape); + self.left_in_current_block -= 1; + } + } + } + } + } + } +} diff --git a/avro/src/state_machines/reading/bytes.rs b/avro/src/state_machines/reading/bytes.rs new file mode 100644 index 00000000..68bc529c --- /dev/null +++ b/avro/src/state_machines/reading/bytes.rs @@ -0,0 +1,71 @@ +use oval::Buffer; + +use crate::{ + error::Details, + state_machines::reading::{StateMachine, StateMachineControlFlow, decode_zigzag_buffer}, +}; + +use super::StateMachineResult; + +// TODO: Also make a String specific state machine. This allows checking the utf-8 while parsing +// which would make the parser fail quicker on large invalid strings. +// TODO: This state machine could also produce inline strings (smolstr) for strings smaller than +// size_of::, and use some extra bits to store well-known strings +// like avro.schema and avro.codec as fixed strings. + +#[derive(Default)] +pub struct BytesStateMachine { + length: Option, + data: Vec, +} + +impl BytesStateMachine { + pub fn new() -> Self { + Self { + length: None, + data: Vec::new(), + } + } + + pub fn new_with_length(length: usize) -> Self { + Self { + length: Some(length), + data: Vec::with_capacity(length), + } + } +} + +impl StateMachine for BytesStateMachine { + // This is a Vec instead of a Box<[u8]> as it's easier to create a string from a vec + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + if self.length.is_none() { + let Some(length) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer varint byte plus we know + // there at least 127 bytes in the buffer now (as otherwise we wouldn't need one more varint byte). 
+ return Ok(StateMachineControlFlow::NeedMore(self)); + }; + let length = + usize::try_from(length).map_err(|e| Details::ConvertI64ToUsize(e, length))?; + self.length = Some(length); + self.data.reserve_exact(length); + } + // This was just set in the previous if statement and it returns if that was not possible to do. + let Some(length) = self.length else { + unreachable!() + }; + + // How much more data is needed + let remaining = length - self.data.len(); + // How much of that is available in the buffer + let available = remaining.min(buffer.available_data()); + self.data.extend_from_slice(&buffer.data()[..available]); + buffer.consume(available); + if remaining - available == 0 { + Ok(StateMachineControlFlow::Done(self.data)) + } else { + Ok(StateMachineControlFlow::NeedMore(self)) + } + } +} diff --git a/avro/src/state_machines/reading/codec.rs b/avro/src/state_machines/reading/codec.rs new file mode 100644 index 00000000..b823ee6c --- /dev/null +++ b/avro/src/state_machines/reading/codec.rs @@ -0,0 +1,180 @@ +use crate::{ + Codec, + state_machines::reading::{StateMachine, StateMachineControlFlow, StateMachineResult}, +}; +use oval::Buffer; + +pub struct CodecStateMachine { + sub_machine: Option, + codec: Decoder, + buffer: Buffer, +} + +impl CodecStateMachine { + pub fn new(sub_machine: T, codec: Codec) -> Self { + Self { + sub_machine: Some(sub_machine), + codec: codec.into(), + buffer: Buffer::with_capacity(1024), + } + } + + pub fn reset(&mut self, sub_machine: T) { + self.buffer.reset(); + self.sub_machine = Some(sub_machine); + self.codec.reset(); + } +} + +pub enum Decoder { + Null, + Deflate(Box), + #[cfg(feature = "snappy")] + Snappy(snap::raw::Decoder), + #[cfg(feature = "zstandard")] + Zstandard(zstd::stream::raw::Decoder<'static>), + #[cfg(feature = "bzip")] + Bzip2(bzip2::Decompress), + #[cfg(feature = "xz")] + Xz(xz2::stream::Stream), +} + +impl From for Decoder { + fn from(value: Codec) -> Self { + match value { + Codec::Null => Self::Null, + 
Codec::Deflate(_) => { + use miniz_oxide::{DataFormat::Raw, inflate::stream::InflateState}; + Self::Deflate(InflateState::new_boxed(Raw)) + } + #[cfg(feature = "snappy")] + Codec::Snappy => Self::Snappy(snap::raw::Decoder::new()), + #[cfg(feature = "zstandard")] + Codec::Zstandard(_) => Self::Zstandard(zstd::stream::raw::Decoder::new().unwrap()), + #[cfg(feature = "bzip")] + Codec::Bzip2(_) => Self::Bzip2(bzip2::Decompress::new(false)), + #[cfg(feature = "xz")] + Codec::Xz(_) => Self::Xz(xz2::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap()), + } + } +} + +impl Decoder { + pub fn reset(&mut self) { + match self { + Decoder::Null => {} + Decoder::Deflate(decoder) => { + decoder.reset_as(miniz_oxide::inflate::stream::MinReset); + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => {} // No reset needed + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => zstd::stream::raw::Operation::reinit(decoder).unwrap(), + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace(decoder, bzip2::Decompress::new(false)); + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace( + decoder, + xz2::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap(), + ); + } + } + } +} + +impl StateMachine for CodecStateMachine { + type Output = (T::Output, Self); + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + let buffer = match &mut self.codec { + Decoder::Null => buffer, + Decoder::Deflate(decoder) => { + use miniz_oxide::{MZFlush, StreamResult, inflate::stream::inflate}; + let StreamResult { + bytes_consumed, + bytes_written, + status, + } = inflate(decoder, buffer.data(), self.buffer.space(), MZFlush::None); + status.unwrap(); + buffer.consume(bytes_consumed); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => { + todo!("Snap has no streaming 
decoder") + } + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => { + use zstd::stream::raw::{Operation, Status}; + let Status { + bytes_read, + bytes_written, + .. + } = decoder + .run_on_buffers(buffer.data(), self.buffer.space()) + .map_err(crate::error::Details::ZstdDecompress)?; + buffer.consume(bytes_read); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .decompress(buffer.data(), self.buffer.space()) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + use xz2::stream::Action::Run; + + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .process(buffer.data(), self.buffer.space(), Run) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + }; + match self + .sub_machine + .take() + .expect("CodecStateMachine was not reset!") + .parse(buffer)? 
+ { + StateMachineControlFlow::NeedMore(fsm) => { + self.sub_machine = Some(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(result) => { + Ok(StateMachineControlFlow::Done((result, self))) + } + } + } +} diff --git a/avro/src/state_machines/reading/commands.rs b/avro/src/state_machines/reading/commands.rs new file mode 100644 index 00000000..734d5c04 --- /dev/null +++ b/avro/src/state_machines/reading/commands.rs @@ -0,0 +1,646 @@ +use crate::{ + Error, Schema, + error::Details, + schema::{ + ArraySchema, DecimalSchema, EnumSchema, FixedSchema, MapSchema, Name, Names, RecordSchema, + UnionSchema, + }, + state_machines::reading::{ + ItemRead, SubStateMachine, block::BlockStateMachine, bytes::BytesStateMachine, + datum::DatumStateMachine, union::UnionStateMachine, + }, +}; +use std::{collections::HashMap, ops::Range, sync::Arc}; + +/// The next item type that should be read. +#[must_use] +pub enum ToRead { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, + Enum, + Ref(CommandTape), + Fixed(usize), + Block(CommandTape), + Union { + variants: CommandTape, + num_variants: usize, + }, +} + +impl ToRead { + pub fn into_state_machine(self, read: Vec) -> SubStateMachine { + match self { + ToRead::Null => SubStateMachine::Null(read), + ToRead::Boolean => SubStateMachine::Bool(read), + ToRead::Int => SubStateMachine::Int(read), + ToRead::Long => SubStateMachine::Long(read), + ToRead::Float => SubStateMachine::Float(read), + ToRead::Double => SubStateMachine::Double(read), + ToRead::Enum => SubStateMachine::Enum(read), + ToRead::Bytes => SubStateMachine::Bytes { + fsm: BytesStateMachine::new(), + read, + }, + ToRead::String => SubStateMachine::String { + fsm: BytesStateMachine::new(), + read, + }, + ToRead::Fixed(length) => SubStateMachine::Bytes { + fsm: BytesStateMachine::new_with_length(length), + read, + }, + ToRead::Ref(commands) => { + SubStateMachine::Object(DatumStateMachine::new_with_tape(commands, read)) 
+ } + ToRead::Block(commands) => { + SubStateMachine::Block(BlockStateMachine::new_with_tape(commands, read)) + } + ToRead::Union { + variants, + num_variants, + } => SubStateMachine::Union(UnionStateMachine::new_with_tape( + variants, + num_variants, + read, + )), + } + } +} + +impl std::fmt::Debug for ToRead { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int => write!(f, "Int"), + Self::Long => write!(f, "Long"), + Self::Float => write!(f, "Float"), + Self::Double => write!(f, "Double"), + Self::Bytes => write!(f, "Bytes"), + Self::String => write!(f, "String"), + Self::Enum => write!(f, "Enum"), + // We don't show the Ref command as that could recurse forever + Self::Ref(_) => write!(f, "Ref<...>"), + Self::Fixed(arg0) => write!(f, "Fixed<{arg0}>"), + Self::Block(arg0) => f.debug_tuple("Block").field(arg0).finish(), + Self::Union { variants, .. } => f.debug_tuple("Union").field(variants).finish(), + } + } +} + +/// A section of a tape of commands. +/// +/// This has a reference to the entire tape, so that references to types (for Union,Map,Array) can be resolved. +#[derive(Clone, PartialEq)] +#[must_use] +pub struct CommandTape { + inner: Arc<[u8]>, + read_range: Range, +} + +impl CommandTape { + pub const NULL: u8 = 0; + pub const BOOLEAN: u8 = 1; + pub const INT: u8 = 2; + pub const LONG: u8 = 3; + pub const FLOAT: u8 = 4; + pub const DOUBLE: u8 = 5; + pub const BYTES: u8 = 6; + pub const STRING: u8 = 7; + pub const ENUM: u8 = 8; + /// A fixed amount of bytes. + /// + /// If the amount of bytes is smaller than or equal to `0xF`, the amount is stored in the four + /// most significant bits of the byte. Otherwise, it's stored as a native endian usize directly + /// after the command byte. + pub const FIXED: u8 = 9; + /// A block based format follows (i.e. Map or Array). 
+ /// + /// The command sequence of the type in the block follows immediately after the command byte. + /// The length of the sequence is stored in the most significant four bits of the command byte. + /// If the sequence is larger than `0xF`, then either the entire sequence or part of it is + /// put behind a [`Self::REF`]. + pub const BLOCK: u8 = 10; + pub const UNION: u8 = 11; + /// A reference to a command sequence somewhere else in the tape. + /// + /// If the length of the sequence is smaller than or equal to `0xF`, the length is stored in the + /// four most significant bits of the byte. Otherwise, it's stored as a native endian usize + /// directly after the command byte. After the length follows the offset as a native endian + /// usize. + pub const REF: u8 = 12; + /// Skip the next `n` commands. + /// + /// A SKIP command is not counted as a command. + /// + /// If `n` is smaller than or equal to `0xF`, the amount is stored in the four most significant + /// bits of the byte. Otherwise, it's stored as a native endian usize directly after the command + /// byte. + pub const SKIP: u8 = 13; + + /// Create a new tape that will be read from start to end. + pub fn new(command_tape: Arc<[u8]>) -> Self { + let length = command_tape.len(); + Self { + inner: command_tape, + read_range: 0..length, + } + } + + pub fn build_from_schema(schema: &Schema, names: &Names) -> Result { + CommandTapeBuilder::build(schema, names) + } + + /// Check if the section of the tape we're reading is finished. + pub fn is_finished(&self) -> bool { + self.read_range.is_empty() + } + + /// Extract a part from the tape to give to a sub-state machine. + /// + /// The tape will run from offset for the given amount of commands. 
+ pub fn extract(&self, offset: usize, commands: usize) -> Self { + let mut temp = Self { + inner: self.inner.clone(), + read_range: offset..self.inner.len(), + }; + temp.skip(commands); + let max_index = temp.read_range.next().unwrap_or(self.inner.len()); + + assert!( + max_index <= self.inner.len(), + "Reference is (partly) outside the tape" + ); + Self { + inner: self.inner.clone(), + read_range: offset..max_index, + } + } + + /// Extract many parts from the tape to give to the Union state machine. + /// + /// The tapes will run from start to end (inclusive). + pub fn extract_many(&self, parts: &[(usize, usize)]) -> Box<[Self]> { + let mut vec = Vec::with_capacity(parts.len()); + for &(start, end) in parts { + vec.push(self.extract(start, end)); + } + vec.into_boxed_slice() + } + + /// Read an array of bytes from the tape. + fn read_array(&mut self) -> [u8; N] { + let start = self.read_range.next().expect("Read past the limit"); + let end = self.read_range.nth(N - 2).expect("Read past the limit"); + self.inner[start..=end].try_into().expect("Unreachable!") + } + + fn read_inline_or(&mut self, byte: u8) -> usize { + if byte >> 4 != 0 { + // Length is stored inline + (byte >> 4) as usize + } else { + usize::from_ne_bytes(self.read_array()) + } + } + + /// Get the next command from the tape. + /// + /// Will return `None` if exhausted. 
+ pub fn command(&mut self) -> Option { + if let Some(position) = self.read_range.next() { + let byte = self.inner[position]; + match byte & 0xF { + Self::NULL => Some(ToRead::Null), + Self::BOOLEAN => Some(ToRead::Boolean), + Self::INT => Some(ToRead::Int), + Self::LONG => Some(ToRead::Long), + Self::FLOAT => Some(ToRead::Float), + Self::DOUBLE => Some(ToRead::Double), + Self::BYTES => Some(ToRead::Bytes), + Self::STRING => Some(ToRead::String), + Self::ENUM => Some(ToRead::Enum), + Self::FIXED => Some(ToRead::Fixed(self.read_inline_or(byte))), + Self::BLOCK => { + // ToRead::Block + let size = (byte >> 4) as usize; + self.skip(size); + Some(ToRead::Block(self.extract(position + 1, size))) + } + Self::UNION => { + // How many variants are there? + let num_variants = self.read_inline_or(byte); + + // Skip over the union variants while keeping track of their start and end + // so we can easily create the command tape + let start = self.read_range.start; + self.skip(num_variants); + let end = self.read_range.start; + + // Create the command tape from the previously tracked start and end + let mut tape = self.clone(); + tape.read_range.start = start; + tape.read_range.end = end; + + Some(ToRead::Union { + variants: tape, + num_variants, + }) + } + Self::REF => { + let size = self.read_inline_or(byte); + let offset = usize::from_ne_bytes(self.read_array()); + Some(ToRead::Ref(self.extract(offset, size))) + } + Self::SKIP => { + // Read how many commands to skip and skip them + let commands = self.read_inline_or(byte); + self.skip(commands); + + // Return the next command + self.command() + } + _ => unreachable!(), // TODO: There is room here to specialize certain types, like a Union of Null and some other type + } + } else { + None + } + } + + /// Skip `amount` commands. + /// + /// If a command contains subcommands, these will also be skipped. 
+ /// + /// # Returns + /// `None` if it read past the end of the tape + pub(crate) fn skip(&mut self, mut amount: usize) -> Option<()> { + let mut i = 0; + while i < amount { + let position = self.read_range.next()?; + let byte = self.inner[position]; + match byte & 0xF { + CommandTape::BOOLEAN + | CommandTape::INT + | CommandTape::LONG + | CommandTape::FLOAT + | CommandTape::DOUBLE + | CommandTape::BYTES + | CommandTape::STRING + | CommandTape::ENUM + | CommandTape::NULL => {} + CommandTape::FIXED => { + let _size = self.read_inline_or(byte); + } + CommandTape::REF => { + let _size = self.read_inline_or(byte); + let _offset = usize::from_ne_bytes(self.read_array()); + } + CommandTape::UNION | CommandTape::BLOCK | CommandTape::SKIP => { + // These commands can inline other commands, so add them to the skip list + let num_variants = self.read_inline_or(byte); + amount += num_variants; + + // Skip does not count as a command, but we do increment `i` so we compensate + // for that by incrementing the amount + if byte & 0xF == CommandTape::SKIP { + amount += 1; + } + } + _ => unreachable!(), + } + i += 1; + } + Some(()) + } +} + +impl std::fmt::Debug for CommandTape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut c = self.clone(); + + write!(f, "CommandTape: ")?; + let mut list = f.debug_list(); + while let Some(command) = c.command() { + list.entry(&command); + } + list.finish() + } +} + +struct CommandTapeBuilder<'a> { + tape: Vec, + references: HashMap<&'a Name, (usize, usize)>, + names: &'a Names, +} + +impl<'a> CommandTapeBuilder<'a> { + pub fn new(names: &'a Names) -> Self { + Self { + tape: Vec::new(), + references: HashMap::new(), + names, + } + } + + fn add_schema(&mut self, schema: &'a Schema, inline_up_to: usize) -> Result { + match schema { + Schema::Null => { + self.tape.push(CommandTape::NULL); + Ok(1) + } + Schema::Boolean => { + self.tape.push(CommandTape::BOOLEAN); + Ok(1) + } + Schema::Int | Schema::Date | 
Schema::TimeMillis => { + self.tape.push(CommandTape::INT); + Ok(1) + } + Schema::Long + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => { + self.tape.push(CommandTape::LONG); + Ok(1) + } + Schema::Float => { + self.tape.push(CommandTape::FLOAT); + Ok(1) + } + Schema::Double => { + self.tape.push(CommandTape::DOUBLE); + Ok(1) + } + Schema::Bytes | Schema::BigDecimal => { + self.tape.push(CommandTape::BYTES); + Ok(1) + } + Schema::String | Schema::Uuid => { + self.tape.push(CommandTape::STRING); + Ok(1) + } + Schema::Array(ArraySchema { items, .. }) => { + let block_offset = self.tape.len(); + self.tape.push(CommandTape::BLOCK); + let commands = self.add_schema(items, 16)?; + self.tape[block_offset] = CommandTape::BLOCK | (commands << 4) as u8; + Ok(1) + } + Schema::Map(MapSchema { types, .. }) => { + let block_offset = self.tape.len(); + self.tape.push(CommandTape::BLOCK); + self.tape.push(CommandTape::STRING); + let commands = self.add_schema(types, 15)?; + self.tape[block_offset] = CommandTape::BLOCK | ((commands + 1) << 4) as u8; + Ok(1) + } + Schema::Union(UnionSchema { schemas, .. }) => { + let schema_len = schemas.len(); + if 0 < schema_len && schema_len <= 0xF { + self.tape.push(CommandTape::UNION | (schema_len << 4) as u8); + } else { + self.tape.push(CommandTape::UNION); + self.tape.extend_from_slice(&schema_len.to_ne_bytes()); + } + for schema in schemas { + self.add_schema(schema, 1)?; + } + Ok(1) + } + Schema::Record(RecordSchema { name, fields, .. }) => { + if let Some(&(offset, commands)) = self.references.get(name) { + self.add_reference(offset, commands); + Ok(1) + } else if fields.is_empty() { + panic!("Record has no fields! 
{schema:?}"); + } else { + let commands = fields.len(); + if commands > inline_up_to { + // If this record is larger than the amount we're allowed to inline, inject + // a SKIP command. + if commands <= 0xF { + self.tape.push(CommandTape::SKIP | (commands << 4) as u8); + } else { + self.tape.push(CommandTape::SKIP); + self.tape.extend_from_slice(&commands.to_ne_bytes()); + } + } + let offset = self.tape.len(); + self.references.insert(name, (offset, commands)); + for field in fields { + let _commands = self.add_schema(&field.schema, 1)?; + } + if commands > inline_up_to { + // Now refer back to the skip block + self.add_reference(offset, commands); + Ok(1) + } else { + Ok(commands) + } + } + } + Schema::Enum(EnumSchema { name, .. }) => { + let offset = self.tape.len(); + let commands = 1; + self.tape.push(CommandTape::ENUM); + self.references.insert(name, (offset, commands)); + Ok(1) + } + Schema::Fixed(FixedSchema { name, size, .. }) => { + let offset = self.tape.len(); + if 0 < *size && *size <= 0xF { + self.tape.push(CommandTape::FIXED | (*size << 4) as u8); + } else { + self.tape.push(CommandTape::FIXED); + self.tape.extend_from_slice(&size.to_ne_bytes()); + } + self.references.entry(name).or_insert((offset, 1)); + Ok(1) + } + Schema::Decimal(DecimalSchema { inner, .. 
}) => self.add_schema(inner, inline_up_to), + Schema::Duration => { + self.tape.push(CommandTape::FIXED | 12 << 4); + Ok(1) + } + Schema::Ref { name } => { + if let Some(&(offset, commands)) = self.references.get(name) { + self.add_reference(offset, commands); + Ok(1) + } else if let Some(schema) = self.names.get(name).as_ref() { + self.add_schema(schema, inline_up_to) + } else { + Err(Details::SchemaResolutionError(name.clone()).into()) + } + } + } + } + + fn add_reference(&mut self, offset: usize, commands: usize) { + if commands == 0 { + self.tape.push(CommandTape::NULL); + } else if commands <= 0xF { + self.tape.push(CommandTape::REF | (commands << 4) as u8); + } else { + self.tape.push(CommandTape::REF); + self.tape.extend_from_slice(&commands.to_ne_bytes()); + } + self.tape.extend_from_slice(&offset.to_ne_bytes()); + } + + pub fn build(schema: &Schema, names: &'a Names) -> Result { + let mut builder = Self::new(names); + + builder.add_schema(schema, usize::MAX)?; + + let tape_len = builder.tape.len(); + Ok(CommandTape { + inner: Arc::from(builder.tape), + read_range: 0..tape_len, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn command_tape_simple() { + assert_eq!( + CommandTape::build_from_schema(&Schema::Null, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::NULL] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Boolean, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::BOOLEAN] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Int, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Date, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimeMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Long, &HashMap::new()) + 
.unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimeMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampNanos, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampNanos, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Float, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::FLOAT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Double, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::DOUBLE] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Bytes, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::BYTES] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::String, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::STRING] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Uuid, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::STRING] + ); + } +} diff --git a/avro/src/state_machines/reading/datum.rs b/avro/src/state_machines/reading/datum.rs new file mode 100644 index 
00000000..2e9a4beb --- /dev/null +++ b/avro/src/state_machines/reading/datum.rs @@ -0,0 +1,80 @@ +use oval::Buffer; + +use super::StateMachineResult; +use crate::state_machines::reading::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, SubStateMachine, +}; + +enum TapeOrFsm { + Tape(Vec), + Fsm(Box), +} + +pub struct DatumStateMachine { + command_tape: CommandTape, + tape_or_fsm: TapeOrFsm, +} + +impl DatumStateMachine { + /// Create a new state machine that reads a datum from the commands. + pub fn new(command_tape: CommandTape) -> Self { + Self::new_with_tape(command_tape, Vec::new()) + } + + /// Create a new state machine that appends to the tape (the tape is returned on completion). + pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { + Self { + command_tape, + tape_or_fsm: TapeOrFsm::Tape(tape), + } + } +} + +impl StateMachine for DatumStateMachine { + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + // While there's data and commands to process we keep progressing the state machines + while !buffer.data().is_empty() { + match self.tape_or_fsm { + TapeOrFsm::Fsm(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(Box::new(fsm)); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(read) => { + self.tape_or_fsm = TapeOrFsm::Tape(read); + } + }, + TapeOrFsm::Tape(tape) => { + if let Some(command) = self.command_tape.command() { + let fsm = command.into_state_machine(tape); + // This is a duplicate of the TapeOrFsm::Fsm logic, but saves us an allocation + // by doing it immediately. + match fsm.parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(Box::new(fsm)); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(read) => { + self.tape_or_fsm = TapeOrFsm::Tape(read); + } + } + } else { + self.tape_or_fsm = TapeOrFsm::Tape(tape); + break; + } + } + } + } + + // Check if we're completely finished or need more data + match (self.tape_or_fsm, self.command_tape.is_finished()) { + (TapeOrFsm::Tape(read), true) => Ok(StateMachineControlFlow::Done(read)), + (tape_or_fsm, _) => { + self.tape_or_fsm = tape_or_fsm; + Ok(StateMachineControlFlow::NeedMore(self)) + } + } + } +} diff --git a/avro/src/state_machines/reading/error.rs b/avro/src/state_machines/reading/error.rs new file mode 100644 index 00000000..12bcefd8 --- /dev/null +++ b/avro/src/state_machines/reading/error.rs @@ -0,0 +1,16 @@ +use crate::{Schema, state_machines::reading::ItemRead}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ValueFromTapeError { + #[error("Unexpected end of tape while building Value")] + UnexpectedEndOfTape, + #[error( + "Mismatch between tape and schema while building Value: schema {schema}, tape: {item:?}" + )] + TapeSchemaMismatch { schema: Schema, item: ItemRead }, + #[error( + "Mismatch between tape and schema while building Value: Schema::Fixed expected {expected} bytes, but tape had {actual}" + )] + TapeSchemaMismatchFixed { expected: usize, actual: usize }, +} diff --git a/avro/src/state_machines/reading/mod.rs b/avro/src/state_machines/reading/mod.rs new file mode 100644 index 00000000..465642ac --- /dev/null +++ b/avro/src/state_machines/reading/mod.rs @@ -0,0 +1,1151 @@ +use crate::{ + Decimal, Duration, Error, Schema, + bigdecimal::deserialize_big_decimal, + error::Details, + schema::{ + ArraySchema, EnumSchema, FixedSchema, MapSchema, Name, Names, Namespace, RecordSchema, + ResolvedSchema, UnionSchema, + }, + state_machines::reading::{ + block::BlockStateMachine, 
bytes::BytesStateMachine, commands::CommandTape, + datum::DatumStateMachine, error::ValueFromTapeError, union::UnionStateMachine, + }, + types::Value, + util::decode_variable, +}; +use oval::Buffer; +use serde::Deserialize; +use std::{borrow::Borrow, collections::HashMap, io::Read, ops::Deref, str::FromStr}; +use uuid::Uuid; + +pub mod async_impl; +pub mod block; +pub mod bytes; +pub mod codec; +mod commands; +pub mod datum; +pub mod error; +mod object_container_file; +pub mod sync; +mod union; + +pub trait StateMachine: Sized { + type Output: Sized; + + /// Start/continue the state machine. + /// + /// Implementers are not allowed to return until they can't make progress anymore. + fn parse(self, buffer: &mut Buffer) -> StateMachineResult; +} + +/// Indicates whether the state machine has completed or needs to be polled again. +#[must_use] +pub enum StateMachineControlFlow { + /// The state machine needs more data before it can continue. + NeedMore(StateMachine), + /// The state machine is done and the result is returned.s + Done(Output), +} + +pub type StateMachineResult = + Result, Error>; + +/// The sub state machine that is currently being driven. +/// +/// The `Int`, `Long`, `Float`, `Double`, and `Enum` statemachines don't have state, as +/// they don't consume the buffer if there are not enough bytes. This means that the only +/// thing these statemachines are keeping track of is which type we're actually decoding. 
+pub enum SubStateMachine { + Null(Vec), + Bool(Vec), + Int(Vec), + Long(Vec), + Float(Vec), + Double(Vec), + Enum(Vec), + Bytes { + fsm: BytesStateMachine, + read: Vec, + }, + String { + fsm: BytesStateMachine, + read: Vec, + }, + Block(BlockStateMachine), + Object(DatumStateMachine), + Union(UnionStateMachine), +} + +impl StateMachine for SubStateMachine { + type Output = Vec; + + fn parse(self, buffer: &mut Buffer) -> StateMachineResult { + match self { + SubStateMachine::Null(mut read) => { + read.push(ItemRead::Null); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Bool(mut read) => { + let mut byte = [0; 1]; + buffer + .read_exact(&mut byte) + .expect("Unreachable! Buffer is not empty"); + match byte { + [0] => read.push(ItemRead::Boolean(false)), + [1] => read.push(ItemRead::Boolean(true)), + [byte] => return Err(Details::BoolValue(byte).into()), + } + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Int(mut read) => { + let Some(n) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Int(read))); + }; + let n = i32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + read.push(ItemRead::Int(n)); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Long(mut read) => { + let Some(n) = decode_zigzag_buffer(buffer)? 
else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Long(read))); + }; + read.push(ItemRead::Long(n)); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Float(mut read) => { + let Some(bytes) = buffer.data().first_chunk().copied() else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Float(read))); + }; + buffer.consume(4); + read.push(ItemRead::Float(f32::from_le_bytes(bytes))); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Double(mut read) => { + let Some(bytes) = buffer.data().first_chunk().copied() else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Double(read))); + }; + buffer.consume(8); + read.push(ItemRead::Double(f64::from_le_bytes(bytes))); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Enum(mut read) => { + let Some(n) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(Self::Enum(read))); + }; + // TODO: Wrong error + let n = u32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + read.push(ItemRead::Enum(n)); + Ok(StateMachineControlFlow::Done(read)) + } + SubStateMachine::Bytes { fsm, mut read } => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::Bytes { fsm, read })) + } + StateMachineControlFlow::Done(bytes) => { + read.push(ItemRead::Bytes(bytes)); + Ok(StateMachineControlFlow::Done(read)) + } + }, + SubStateMachine::String { fsm, mut read } => match fsm.parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::String { + fsm, + read, + })) + } + StateMachineControlFlow::Done(bytes) => { + let string = String::from_utf8(bytes).map_err(Details::ConvertToUtf8)?; + read.push(ItemRead::String(string)); + Ok(StateMachineControlFlow::Done(read)) + } + }, + SubStateMachine::Block(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::Block(fsm))) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + }, + SubStateMachine::Union(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::Union(fsm))) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + }, + SubStateMachine::Object(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + Ok(StateMachineControlFlow::NeedMore(Self::Object(fsm))) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + }, + } + } +} + +/// A item that was read from the document. +#[derive(Debug)] +#[must_use] +pub enum ItemRead { + Null, + Boolean(bool), + Int(i32), + Long(i64), + Float(f32), + Double(f64), + // TODO: smollvec/hipbytes? + Bytes(Vec), + // TODO: smollstr/hipstr? + String(String), + /// The variant of the Enum that was read. + Enum(u32), + /// The variant of the Union that was read. + /// + /// The variant data is next. + Union(u32), + /// The start of a block of a Map or Array. + Block(usize), +} + +/// Read a zigzagged varint from the buffer. +/// +/// Will only consume the buffer if a whole number has been read. +/// If insufficient bytes are available it will return `Ok(None)` to +/// indicate it needs more bytes. +pub fn decode_zigzag_buffer(buffer: &mut Buffer) -> Result, Error> { + if let Some((decoded, consumed)) = decode_variable(buffer.data())? 
{ + buffer.consume(consumed); + Ok(Some(decoded)) + } else { + Ok(None) + } +} + +/// Deserialize a tape to a [`Value`] using the provided [`Schema`]. +/// +/// The schema must be compatible with the schema used by the original writer. +/// +/// Both `names` and `extra_names` are checked when a [`Schema::Ref`] is encountered. They're allowed +/// to have overlapping items. +/// +/// # Panics +/// Can panic if the provided schema does not exactly match the schema used to create the tape. To +/// convert between the writer and reader schema use [`Value::resolve`] instead. +pub fn value_from_tape( + tape: &mut Vec, + schema: &Schema, + names: &Names, +) -> Result { + value_from_tape_internal(&mut tape.drain(..), schema, &None, names) +} + +/// Recursively transform the `tape` into a [`Value`] according to the provided [`Schema`]. +/// +/// Both `names` and `extra_names` are checked when a [`Schema::Ref`] is encountered. They're allowed +/// to have overlapping items. +pub fn value_from_tape_internal( + tape: &mut impl Iterator, + schema: &Schema, + enclosing_namespace: &Namespace, + names: &Names, +) -> Result { + match schema { + Schema::Null => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Null => Ok(Value::Null), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Boolean => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Boolean(bool) => Ok(Value::Boolean(bool)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Int => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Int(bool) => Ok(Value::Int(bool)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Long => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Long(long) => Ok(Value::Long(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Float => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Float(float) => Ok(Value::Float(float)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Double => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Double(double) => Ok(Value::Double(double)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Bytes => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => Ok(Value::Bytes(bytes)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::String => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::String(string) => Ok(Value::String(string)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Array(ArraySchema { items, .. }) => { + let mut collected = Vec::new(); + loop { + let n = match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Block(n) => Ok(n), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + }), + }?; + if n == 0 { + break; + } + collected.reserve(n); + for _ in 0..n { + collected.push(value_from_tape_internal( + tape, + items, + enclosing_namespace, + names, + )?); + } + } + Ok(Value::Array(collected)) + } + Schema::Map(MapSchema { types, .. }) => { + let mut collected = HashMap::new(); + loop { + let n = match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Block(n) => Ok(n), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + }), + }?; + if n == 0 { + break; + } + collected.reserve(n); + for _ in 0..n { + let key = match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::String(string) => Ok(string), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: Schema::String, + item, + }), + }?; + let val = value_from_tape_internal(tape, types, enclosing_namespace, names)?; + collected.insert(key, val); + } + } + Ok(Value::Map(collected)) + } + Schema::Union(UnionSchema { schemas, .. }) => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Union(variant) => { + let schema = schemas.get(usize::try_from(variant).unwrap()).ok_or( + Details::GetUnionVariant { + index: variant as i64, + num_variants: schemas.len(), + }, + )?; + let value = Box::new(value_from_tape_internal( + tape, + schema, + enclosing_namespace, + names, + )?); + Ok(Value::Union(variant, value)) + } + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Record(RecordSchema { name, fields, .. }) => { + let fqn = name.fully_qualified_name(enclosing_namespace); + let mut collected = Vec::with_capacity(fields.len()); + for field in fields { + let collect = value_from_tape_internal(tape, &field.schema, &fqn.namespace, names)?; + collected.push((field.name.clone(), collect)); + } + Ok(Value::Record(collected)) + } + Schema::Enum(EnumSchema { symbols, .. }) => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Enum(val) => Ok(Value::Enum( + val, + symbols.get(usize::try_from(val).unwrap()).unwrap().clone(), + )), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Fixed(FixedSchema { size, .. }) => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Bytes(fixed) => { + if *size == fixed.len() { + Ok(Value::Fixed(fixed.len(), fixed)) + } else { + Err(ValueFromTapeError::TapeSchemaMismatchFixed { + expected: *size, + actual: fixed.len(), + } + .into()) + } + } + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Decimal(_) => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => Ok(Value::Decimal(Decimal::from(&bytes))), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::BigDecimal => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Bytes(bytes) => deserialize_big_decimal(&bytes).map(Value::BigDecimal), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Uuid => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::String(string) => Uuid::from_str(&string) + .map(Value::Uuid) + .map_err(|e| Details::ConvertStrToUuid(e).into()), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Date => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Int(int) => Ok(Value::Date(int)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::TimeMillis => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Int(int) => Ok(Value::TimeMillis(int)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::TimeMicros => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Long(long) => Ok(Value::TimeMicros(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::TimestampMillis => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimestampMillis(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::TimestampMicros => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimestampMicros(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::TimestampNanos => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::TimestampNanos(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::LocalTimestampMillis => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::LocalTimestampMillis(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::LocalTimestampMicros => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::LocalTimestampMicros(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::LocalTimestampNanos => { + match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + ItemRead::Long(long) => Ok(Value::LocalTimestampNanos(long)), + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + } + } + Schema::Duration => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ + ItemRead::Bytes(bytes) => { + let array: [u8; 12] = bytes.deref().try_into().unwrap(); + Ok(Value::Duration(Duration::from(array))) + } + item => Err(ValueFromTapeError::TapeSchemaMismatch { + schema: schema.clone(), + item, + } + .into()), + }, + Schema::Ref { name } => { + let fqn = name.fully_qualified_name(enclosing_namespace); + if let Some(resolved) = names.get(&fqn) { + value_from_tape_internal(tape, resolved, &fqn.namespace, names) + } else { + Err(Details::SchemaResolutionError(fqn).into()) + } + } + } +} + +/// Deserialize a tape to `T` using the provided [`Schema`]. +/// +/// The schema must be compatible with the schema used by the original writer. +pub fn deserialize_from_tape<'a, T: Deserialize<'a>>( + tape: &mut Vec, + schema: &Schema, +) -> Result { + let rs = ResolvedSchema::try_from(schema)?; + deserialize_from_tape_internal(tape, schema, rs.get_names(), &None) +} + +/// Recursively transform the `tape` into a `T` according to the provided [`Schema`]. +fn deserialize_from_tape_internal<'a, T: Deserialize<'a>, S: Borrow>( + tape: &mut Vec, + _schema: &Schema, + _names: &HashMap, + _enclosing_namespace: &Namespace, +) -> Result { + tape.clear(); + todo!() +} + +#[cfg(test)] +#[allow(clippy::expect_fun_call)] +mod tests { + use crate::{ + Decimal, + encode::{encode, tests::success}, + from_avro_datum, + schema::{DecimalSchema, FixedSchema, Schema}, + types::{ + Value, + Value::{Array, Int, Map}, + }, + }; + use apache_avro_test_helper::TestResult; + use pretty_assertions::assert_eq; + use std::collections::HashMap; + use uuid::Uuid; + + #[test] + fn test_decode_array_without_size() -> TestResult { + let mut input: &[u8] = &[6, 2, 4, 6, 0]; + + let result = from_avro_datum(&Schema::array(Schema::Int), &mut input, None)?; + + assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result); + + Ok(()) + } + + #[test] + fn test_decode_array_with_size() -> TestResult { + let mut input: &[u8] = &[5, 6, 2, 4, 6, 0]; + let result = 
from_avro_datum(&Schema::array(Schema::Int), &mut input, None)?; + assert_eq!(Array(vec!(Int(1), Int(2), Int(3))), result); + + Ok(()) + } + + #[test] + fn test_decode_map_without_size() -> TestResult { + let mut input: &[u8] = &[0x02, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; + let result = from_avro_datum(&Schema::map(Schema::Int), &mut input, None)?; + let mut expected = HashMap::new(); + expected.insert(String::from("test"), Int(1)); + assert_eq!(Map(expected), result); + + Ok(()) + } + + #[test] + fn test_decode_map_with_size() -> TestResult { + let mut input: &[u8] = &[0x01, 0x0C, 0x08, 0x74, 0x65, 0x73, 0x74, 0x02, 0x00]; + let result = from_avro_datum(&Schema::map(Schema::Int), &mut input, None)?; + let mut expected = HashMap::new(); + expected.insert(String::from("test"), Int(1)); + assert_eq!(Map(expected), result); + + Ok(()) + } + + #[test] + fn test_negative_decimal_value() -> TestResult { + use crate::{encode::encode, schema::Name}; + use num_bigint::ToBigInt; + let inner = Box::new(Schema::Fixed( + FixedSchema::builder() + .name(Name::new("decimal")?) 
+ .size(2) + .build(), + )); + let schema = Schema::Decimal(DecimalSchema { + inner, + precision: 4, + scale: 2, + }); + let bigint = (-423).to_bigint().unwrap(); + let value = Value::Decimal(Decimal::from(bigint.to_signed_bytes_be())); + + let mut buffer = Vec::new(); + encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); + + let mut bytes = &buffer[..]; + let result = from_avro_datum(&schema, &mut bytes, None)?; + assert_eq!(result, value); + + Ok(()) + } + + #[test] + fn test_decode_decimal_with_bigger_than_necessary_size() -> TestResult { + use crate::{encode::encode, schema::Name}; + use num_bigint::ToBigInt; + let inner = Box::new(Schema::Fixed(FixedSchema { + size: 13, + name: Name::new("decimal")?, + aliases: None, + doc: None, + default: None, + attributes: Default::default(), + })); + let schema = Schema::Decimal(DecimalSchema { + inner, + precision: 4, + scale: 2, + }); + let value = Value::Decimal(Decimal::from( + ((-423).to_bigint().unwrap()).to_signed_bytes_be(), + )); + let mut buffer = Vec::::new(); + + encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); + let mut bytes: &[u8] = &buffer[..]; + let result = from_avro_datum(&schema, &mut bytes, None)?; + assert_eq!(result, value); + + Ok(()) + } + + #[test] + fn test_avro_3448_recursive_definition_decode_union() -> TestResult { + // if encoding fails in this test check the corresponding test in encode + let schema = Schema::parse_str( + r#" + { + "type":"record", + "name":"TestStruct", + "fields": [ + { + "name":"a", + "type":[ "null", { + "type":"record", + "name": "Inner", + "fields": [ { + "name":"z", + "type":"int" + }] + }] + }, + { + "name":"b", + "type":"Inner" + } + ] + }"#, + )?; + + let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); + let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); + let outer_value1 = Value::Record(vec![ + ("a".into(), Value::Union(1, Box::new(inner_value1))), + ("b".into(), inner_value2.clone()), 
+ ]); + let mut buf = Vec::new(); + encode(&outer_value1, &schema, &mut buf).expect(&success(&outer_value1, &schema)); + assert!(!buf.is_empty()); + let mut bytes = &buf[..]; + assert_eq!( + outer_value1, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to decode using recursive definitions with schema:\n {:?}\n", + &schema + )) + ); + + let mut buf = Vec::new(); + let outer_value2 = Value::Record(vec![ + ("a".into(), Value::Union(0, Box::new(Value::Null))), + ("b".into(), inner_value2), + ]); + encode(&outer_value2, &schema, &mut buf).expect(&success(&outer_value2, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_value2, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to decode using recursive definitions with schema:\n {:?}\n", + &schema + )) + ); + + Ok(()) + } + + #[test] + fn test_avro_3448_recursive_definition_decode_array() -> TestResult { + let schema = Schema::parse_str( + r#" + { + "type":"record", + "name":"TestStruct", + "fields": [ + { + "name":"a", + "type":{ + "type":"array", + "items": { + "type":"record", + "name": "Inner", + "fields": [ { + "name":"z", + "type":"int" + }] + } + } + }, + { + "name":"b", + "type": "Inner" + } + ] + }"#, + )?; + + let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); + let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); + let outer_value = Value::Record(vec![ + ("a".into(), Value::Array(vec![inner_value1])), + ("b".into(), inner_value2), + ]); + let mut buf = Vec::new(); + encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_value, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to decode using recursive definitions with schema:\n {:?}\n", + &schema + )) + ); + + Ok(()) + } + + #[test] + fn test_avro_3448_recursive_definition_decode_map() -> TestResult { + let schema = Schema::parse_str( + r#" + { + "type":"record", + 
"name":"TestStruct", + "fields": [ + { + "name":"a", + "type":{ + "type":"map", + "values": { + "type":"record", + "name": "Inner", + "fields": [ { + "name":"z", + "type":"int" + }] + } + } + }, + { + "name":"b", + "type": "Inner" + } + ] + }"#, + )?; + + let inner_value1 = Value::Record(vec![("z".into(), Value::Int(3))]); + let inner_value2 = Value::Record(vec![("z".into(), Value::Int(6))]); + let outer_value = Value::Record(vec![ + ( + "a".into(), + Value::Map(vec![("akey".into(), inner_value1)].into_iter().collect()), + ), + ("b".into(), inner_value2), + ]); + let mut buf = Vec::new(); + encode(&outer_value, &schema, &mut buf).expect(&success(&outer_value, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_value, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to decode using recursive definitions with schema:\n {:?}\n", + &schema + )) + ); + + Ok(()) + } + + #[test] + fn test_avro_3448_proper_multi_level_decoding_middle_namespace() -> TestResult { + // if encoding fails in this test check the corresponding test in encode + let schema = r#" + { + "name": "record_name", + "namespace": "space", + "type": "record", + "fields": [ + { + "name": "outer_field_1", + "type": [ + "null", + { + "type": "record", + "name": "middle_record_name", + "namespace":"middle_namespace", + "fields":[ + { + "name":"middle_field_1", + "type":[ + "null", + { + "type":"record", + "name":"inner_record_name", + "fields":[ + { + "name":"inner_field_1", + "type":"double" + } + ] + } + ] + } + ] + } + ] + }, + { + "name": "outer_field_2", + "type" : "middle_namespace.inner_record_name" + } + ] + } + "#; + let schema = Schema::parse_str(schema)?; + let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]); + let middle_record_variation_1 = Value::Record(vec![( + "middle_field_1".into(), + Value::Union(0, Box::new(Value::Null)), + )]); + let middle_record_variation_2 = Value::Record(vec![( + "middle_field_1".into(), + Value::Union(1, 
Box::new(inner_record.clone())), + )]); + let outer_record_variation_1 = Value::Record(vec![ + ( + "outer_field_1".into(), + Value::Union(0, Box::new(Value::Null)), + ), + ("outer_field_2".into(), inner_record.clone()), + ]); + let outer_record_variation_2 = Value::Record(vec![ + ( + "outer_field_1".into(), + Value::Union(1, Box::new(middle_record_variation_1)), + ), + ("outer_field_2".into(), inner_record.clone()), + ]); + let outer_record_variation_3 = Value::Record(vec![ + ( + "outer_field_1".into(), + Value::Union(1, Box::new(middle_record_variation_2)), + ), + ("outer_field_2".into(), inner_record), + ]); + + let mut buf = Vec::new(); + encode(&outer_record_variation_1, &schema, &mut buf) + .expect(&success(&outer_record_variation_1, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_record_variation_1, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", + &schema + )) + ); + + let mut buf = Vec::new(); + encode(&outer_record_variation_2, &schema, &mut buf) + .expect(&success(&outer_record_variation_2, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_record_variation_2, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", + &schema + )) + ); + + let mut buf = Vec::new(); + encode(&outer_record_variation_3, &schema, &mut buf) + .expect(&success(&outer_record_variation_3, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_record_variation_3, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", + &schema + )) + ); + + Ok(()) + } + + #[test] + fn test_avro_3448_proper_multi_level_decoding_inner_namespace() -> TestResult { + // if encoding fails in this test check the corresponding test in encode + let schema = r#" + { + "name": "record_name", + "namespace": "space", + 
"type": "record", + "fields": [ + { + "name": "outer_field_1", + "type": [ + "null", + { + "type": "record", + "name": "middle_record_name", + "namespace":"middle_namespace", + "fields":[ + { + "name":"middle_field_1", + "type":[ + "null", + { + "type":"record", + "name":"inner_record_name", + "namespace":"inner_namespace", + "fields":[ + { + "name":"inner_field_1", + "type":"double" + } + ] + } + ] + } + ] + } + ] + }, + { + "name": "outer_field_2", + "type" : "inner_namespace.inner_record_name" + } + ] + } + "#; + let schema = Schema::parse_str(schema)?; + let inner_record = Value::Record(vec![("inner_field_1".into(), Value::Double(5.4))]); + let middle_record_variation_1 = Value::Record(vec![( + "middle_field_1".into(), + Value::Union(0, Box::new(Value::Null)), + )]); + let middle_record_variation_2 = Value::Record(vec![( + "middle_field_1".into(), + Value::Union(1, Box::new(inner_record.clone())), + )]); + let outer_record_variation_1 = Value::Record(vec![ + ( + "outer_field_1".into(), + Value::Union(0, Box::new(Value::Null)), + ), + ("outer_field_2".into(), inner_record.clone()), + ]); + let outer_record_variation_2 = Value::Record(vec![ + ( + "outer_field_1".into(), + Value::Union(1, Box::new(middle_record_variation_1)), + ), + ("outer_field_2".into(), inner_record.clone()), + ]); + let outer_record_variation_3 = Value::Record(vec![ + ( + "outer_field_1".into(), + Value::Union(1, Box::new(middle_record_variation_2)), + ), + ("outer_field_2".into(), inner_record), + ]); + + let mut buf = Vec::new(); + encode(&outer_record_variation_1, &schema, &mut buf) + .expect(&success(&outer_record_variation_1, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_record_variation_1, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", + &schema + )) + ); + + let mut buf = Vec::new(); + encode(&outer_record_variation_2, &schema, &mut buf) + 
.expect(&success(&outer_record_variation_2, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_record_variation_2, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", + &schema + )) + ); + + let mut buf = Vec::new(); + encode(&outer_record_variation_3, &schema, &mut buf) + .expect(&success(&outer_record_variation_3, &schema)); + let mut bytes = &buf[..]; + assert_eq!( + outer_record_variation_3, + from_avro_datum(&schema, &mut bytes, None).expect(&format!( + "Failed to Decode with recursively defined namespace with schema:\n {:?}\n", + &schema + )) + ); + + Ok(()) + } + + #[test] + fn avro_3926_encode_decode_uuid_to_string() -> TestResult { + use crate::encode::encode; + + let schema = Schema::String; + let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?); + + let mut buffer = Vec::new(); + encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); + + let result = from_avro_datum(&Schema::Uuid, &mut &buffer[..], None)?; + assert_eq!(result, value); + + Ok(()) + } + + // TODO: Schema::Uuid needs a sub schema which is either String or Fixed. It's now part of the + // spec anyway. 
+ // #[test] + // fn avro_3926_encode_decode_uuid_to_fixed() -> TestResult { + // use crate::encode::encode; + // + // let schema = Schema::Fixed(FixedSchema { + // size: 16, + // name: "uuid".into(), + // aliases: None, + // doc: None, + // default: None, + // attributes: Default::default(), + // }); + // let value = Value::Uuid(Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000")?); + // + // let mut buffer = Vec::new(); + // encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); + // + // let result = from_avro_datum(&Schema::Uuid, &mut &buffer[..], None)?; + // assert_eq!(result, value); + // + // Ok(()) + // } +} diff --git a/avro/src/state_machines/reading/object_container_file.rs b/avro/src/state_machines/reading/object_container_file.rs new file mode 100644 index 00000000..8f9c2bd5 --- /dev/null +++ b/avro/src/state_machines/reading/object_container_file.rs @@ -0,0 +1,318 @@ +use crate::{ + Codec, Error, Schema, + error::Details, + schema::{Names, ResolvedSchema, resolve_names, resolve_names_with_schemata}, + state_machines::reading::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, + codec::CodecStateMachine, datum::DatumStateMachine, decode_zigzag_buffer, + }, +}; +use log::warn; +use oval::Buffer; +use serde_json::Value; +use std::{collections::HashMap, io::Read, str::FromStr, sync::Arc}; + +// TODO: Dynamically/const construct this, this one works only on 64-bit LE +/// The tape corresponding to [`HEADER_JSON`]. 
+/// 
+/// ```json
+/// {
+/// "type": "record",
+/// "name": "org.apache.avro.file.HeaderNoMagic",
+/// "fields": [
+/// {"name": "meta", "type": {"type": "map", "values": "bytes"}},
+/// {"name": "sync", "type": {"type": "fixed", "name": "Sync", "size": 16}}
+/// ]
+/// }
+/// ```
+#[rustfmt::skip]
+const HEADER_TAPE: &[u8] = &[
+ CommandTape::BLOCK | 2 << 4, // Starts with a map
+ CommandTape::STRING, // The keys are strings
+ CommandTape::BYTES, // The values are bytes
+ CommandTape::FIXED, // After the map there is a Fixed amount of bytes
+ 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // The amount of bytes is 0x10 (16, the sync marker)
+];
+#[cfg(test)]
+const HEADER_JSON: &str = r#"{"type": "record","name": "org.apache.avro.file.HeaderNoMagic","fields": [{"name": "meta", "type": {"type": "map", "values": "bytes"}},{"name": "sync", "type": {"type": "fixed", "name": "Sync", "size": 16}}]}"#;
+
+/// The header as read from an Object Container file format.
+pub struct ObjectContainerFileHeader {
+ /// The schema used to write the file.
+ pub schema: Schema,
+ pub names: Names,
+ /// The compression used.
+ pub codec: Codec,
+ /// The sync marker used between blocks
+ pub sync: [u8; 16],
+ /// User metadata in the header
+ pub metadata: HashMap>,
+}
+
+impl ObjectContainerFileHeader {
+ pub fn command_tape() -> CommandTape {
+ CommandTape::new(Arc::from(HEADER_TAPE))
+ }
+
+ /// Create the header from an output tape.
+ ///
+ /// # Panics
+ /// Will panic if the tape was not produced from [`Self::command_tape()`]. 
+ pub fn from_tape(mut tape: Vec, mut schemata: Vec<&Schema>) -> Result { + // We want to read the tape from front to back + let mut tape = tape.drain(..); + + let mut schema = None; + let mut codec = None; + let mut found_compression_level = false; + let mut metadata = HashMap::new(); + let mut names = HashMap::new(); + + while let Some(ItemRead::Block(items_left)) = tape.next() { + if items_left == 0 { + // Got to the end of the map + break; + } + for _ in 0..items_left { + let Some(ItemRead::String(key)) = tape.next() else { + panic!("The input does not correspond to the command tape"); + }; + let Some(ItemRead::Bytes(value)) = tape.next() else { + panic!("The input does not correspond to the command tape"); + }; + + match key.as_ref() { + "avro.schema" => { + if schema.is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + let json: Value = + serde_json::from_slice(&value).map_err(Details::ParseSchemaJson)?; + + if !schemata.is_empty() { + // TODO: Make parse_with_names accept NamesRef + let schemata = std::mem::take(&mut schemata); + resolve_names_with_schemata(&schemata, &mut names, &None)?; + + // TODO: Maybe we can not do this, and just past &names to Schema::parse_with_names + let rs = ResolvedSchema::try_from(schemata)?; + let names: Names = rs + .get_names() + .iter() + .map(|(name, &schema)| (name.clone(), schema.clone())) + .collect(); + + let parsed_schema = Schema::parse_with_names(&json, names)?; + schema.replace(parsed_schema); + } else { + let parsed_schema = Schema::parse(&json)?; + resolve_names(&parsed_schema, &mut names, &None)?; + schema.replace(parsed_schema); + } + } + "avro.codec" => { + let string = String::from_utf8(value).map_err(Details::ConvertToUtf8)?; + let parsed_codec = Codec::from_str(&string) + .map_err(|_| Details::CodecNotSupported(string))?; + if codec.replace(parsed_codec).is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + } + "avro.codec.compression_level" => 
{ + // Compression level is not useful for decoding + if found_compression_level { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + found_compression_level = true; + } + _ => { + if key.starts_with("avro.") { + warn!("Ignoring unknown metadata key: {key}"); + } + if metadata.insert(key, value).is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + } + } + } + } + let Some(schema) = schema else { + return Err(Details::GetHeaderMetadata.into()); + }; + let codec = codec.unwrap_or(Codec::Null); + let Some(ItemRead::Bytes(raw_sync)) = tape.next() else { + panic!("The input does not correspond to the command tape"); + }; + let sync = raw_sync + .as_slice() + .try_into() + .expect("The input does not correspond to the command tape"); + Ok(ObjectContainerFileHeader { + schema, + names, + codec, + sync, + metadata, + }) + } +} + +/// A state machine for parsing the header of the Object Container file format. +/// +/// After finishing this state machine the body can be read with [`ObjectContainerFileBodyStateMachine`]. +pub struct ObjectContainerFileHeaderStateMachine<'a> { + /// The actual state machine used to parse the header. + /// + /// This doesn't actually need to be an [`Option`] as it's constructed in [`Self::new`]. However, + /// as [`StateMachine::parse`] takes `self` we need it in an `Option` so we can do [`Option::take`]. 
+ fsm: Option, + read_magic: bool, + schemata: Vec<&'a Schema>, +} + +impl<'a> ObjectContainerFileHeaderStateMachine<'a> { + pub fn new(schemata: Vec<&'a Schema>) -> Self { + let commands = CommandTape::new(Arc::from(HEADER_TAPE)); + Self { + fsm: Some(DatumStateMachine::new(commands)), + read_magic: false, + schemata, + } + } +} + +impl StateMachine for ObjectContainerFileHeaderStateMachine<'_> { + type Output = ObjectContainerFileHeader; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + while !self.read_magic { + if buffer.available_data() < 4 { + return Ok(StateMachineControlFlow::NeedMore(self)); + } + if buffer.data()[0..4] != [b'O', b'b', b'j', 1] { + return Err(Details::HeaderMagic.into()); + } + buffer.consume(4); + self.read_magic = true; + } + match self.fsm.take().expect("Unreachable!").parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + let _ = self.fsm.insert(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(tape) => Ok(StateMachineControlFlow::Done( + ObjectContainerFileHeader::from_tape(tape, self.schemata)?, + )), + } + } +} + +pub struct ObjectContainerFileBodyStateMachine { + fsm: Option>, + tape: CommandTape, + sync: [u8; 16], + left_in_block: usize, + need_to_read_block_byte_size: bool, + need_to_read_sync: bool, +} + +impl ObjectContainerFileBodyStateMachine { + pub fn new(tape: CommandTape, sync: [u8; 16], codec: Codec) -> Self { + Self { + fsm: Some(CodecStateMachine::new( + DatumStateMachine::new(tape.clone()), + codec, + )), + tape, + sync, + left_in_block: 0, + need_to_read_block_byte_size: false, + need_to_read_sync: false, + } + } +} + +impl StateMachine for ObjectContainerFileBodyStateMachine { + type Output = Option<(Vec, Self)>; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + if self.left_in_block == 0 { + if self.need_to_read_sync { + if buffer.available_data() < 16 { + return Ok(StateMachineControlFlow::NeedMore(self)); + } + let mut sync = 
[0; 16]; + assert_eq!( + buffer.read(&mut sync).expect("Unreachable!"), + 16, + "Did not read enough data!" + ); + if sync != self.sync { + return Err(Details::GetBlockMarker.into()); + } + self.need_to_read_sync = false; + } + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + let abs_block = block.unsigned_abs(); + let abs_block = + usize::try_from(abs_block).map_err(|e| Details::ConvertU64ToUsize(e, abs_block))?; + if abs_block == 0 { + // Done parsing the array + return Ok(StateMachineControlFlow::Done(None)); + } + self.need_to_read_block_byte_size = true; + // This will only be done after this block is finished + self.need_to_read_sync = true; + self.left_in_block = abs_block; + } + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + // Make sure the value is sane + let _size = usize::try_from(block).map_err(|e| Details::ConvertI64ToUsize(e, block))?; + self.need_to_read_block_byte_size = false; + } + + match self.fsm.take().expect("Unreachable!").parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.fsm.replace(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done((result, mut codec)) => { + codec.reset(DatumStateMachine::new(self.tape.clone())); + self.fsm.replace(codec); + self.left_in_block -= 1; + Ok(StateMachineControlFlow::Done(Some((result, self)))) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use crate::{ + Schema, + state_machines::reading::{ + commands::CommandTape, + object_container_file::{HEADER_JSON, HEADER_TAPE}, + }, + }; + + #[test] + pub fn header_tape() { + let schema = Schema::parse_str(HEADER_JSON).unwrap(); + let tape = CommandTape::build_from_schema(&schema, &HashMap::new()).unwrap(); + assert_eq!(tape, CommandTape::new(Arc::from(HEADER_TAPE))); + } +} diff --git a/avro/src/state_machines/reading/sync.rs b/avro/src/state_machines/reading/sync.rs new file mode 100644 index 00000000..daab82a0 --- /dev/null +++ b/avro/src/state_machines/reading/sync.rs @@ -0,0 +1,351 @@ +use oval::Buffer; +use serde::Deserialize; +use std::{collections::HashMap, io::Read}; + +use crate::{ + AvroResult, Error, Schema, + error::Details, + schema::{resolve_names, resolve_names_with_schemata}, + state_machines::reading::{ + ItemRead, StateMachine, StateMachineControlFlow, + commands::CommandTape, + datum::DatumStateMachine, + deserialize_from_tape, + object_container_file::{ + ObjectContainerFileBodyStateMachine, ObjectContainerFileHeader, + ObjectContainerFileHeaderStateMachine, + }, + value_from_tape, + }, + types::Value, +}; + +/// Main interface for reading Avro formatted values. 
+/// +/// To be used as an iterator: +/// +/// ```no_run +/// # use apache_avro::Reader; +/// # use std::io::Cursor; +/// # let input = Cursor::new(Vec::::new()); +/// for value in Reader::new(input).unwrap() { +/// match value { +/// Ok(v) => println!("{:?}", v), +/// Err(e) => println!("Error: {}", e), +/// }; +/// } +/// ``` +pub struct Reader<'a, R> { + reader_schema: Option<&'a Schema>, + header: ObjectContainerFileHeader, + fsm: Option, + reader: R, + buffer: Buffer, +} + +impl<'a, R: Read> Reader<'a, R> { + /// Creates a [`Reader`] that will use the schema from the file header. + /// + /// No reader [`Schema`] will be set. + /// + /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`]. + pub fn new(reader: R) -> Result { + Self::new_inner(reader, None, Vec::new()) + } + + /// Creates a [`Reader`] that will use the given schema for schema resolution. + /// + /// **NOTE** The Avro header is going to be read automatically upon creation of the [`Reader`]. + pub fn with_schema(schema: &'a Schema, reader: R) -> Result { + Self::new_inner(reader, Some(schema), Vec::new()) + } + + /// Creates a [`Reader`] that will use the given schema for schema resolution. + /// + /// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be + /// resolved and an error will be returned. + /// + /// Any [`Schema::Ref`] will be resolved using the schemata. + /// + /// **NOTE** The avro header is going to be read automatically upon creation of the [`Reader`]. + pub fn with_schemata( + schema: &'a Schema, + schemata: Vec<&'a Schema>, + reader: R, + ) -> Result { + Self::new_inner(reader, Some(schema), schemata) + } + + /// Get a reference to the writer [`Schema`]. + pub fn writer_schema(&self) -> &Schema { + &self.header.schema + } + + /// Get a reference to the optional reader [`Schema`]. + /// + /// This will only be set if there was a reader schema provided *and* it differed from the + /// writer schema. 
+ pub fn reader_schema(&self) -> Option<&'a Schema> { + self.reader_schema + } + + /// Get a reference to the user metadata. + pub fn user_metadata(&self) -> &HashMap> { + &self.header.metadata + } + + /// Get a reference to the file header. + pub fn header(&self) -> &ObjectContainerFileHeader { + &self.header + } + + fn new_inner( + mut reader: R, + reader_schema: Option<&'a Schema>, + schemata: Vec<&'a Schema>, + ) -> Result { + // Read a maximum of 2Kb per read + let mut buffer = Buffer::with_capacity(2 * 1024); + + // Parse the header + let mut fsm = ObjectContainerFileHeaderStateMachine::new(schemata); + let header = loop { + // Fill the buffer + let n = reader.read(buffer.space()).map_err(Details::ReadHeader)?; + if n == 0 { + return Err(Details::ReadHeader(std::io::ErrorKind::UnexpectedEof.into()).into()); + } + buffer.fill(n); + + // Start/continue the state machine + match fsm.parse(&mut buffer)? { + StateMachineControlFlow::NeedMore(new_fsm) => fsm = new_fsm, + StateMachineControlFlow::Done(header) => break header, + } + }; + + let tape = CommandTape::build_from_schema(&header.schema, &header.names)?; + + let reader_schema = if let Some(schema) = reader_schema + && schema != &header.schema + { + Some(schema) + } else { + None + }; + + Ok(Self { + reader_schema, + fsm: Some(ObjectContainerFileBodyStateMachine::new( + tape, + header.sync, + header.codec, + )), + header, + reader, + buffer, + }) + } + + /// Get the next object in the file + fn next_object(&mut self) -> Option, Error>> { + if let Some(mut fsm) = self.fsm.take() { + loop { + match fsm.parse(&mut self.buffer) { + Ok(StateMachineControlFlow::NeedMore(new_fsm)) => { + fsm = new_fsm; + let n = match self.reader.read(self.buffer.space()) { + Ok(0) => { + return Some(Err(Details::ReadIntoBuf( + std::io::ErrorKind::UnexpectedEof.into(), + ) + .into())); + } + Ok(n) => n, + Err(e) => return Some(Err(Details::ReadIntoBuf(e).into())), + }; + self.buffer.fill(n); + } + 
Ok(StateMachineControlFlow::Done(Some((object, fsm)))) => { + self.fsm.replace(fsm); + return Some(Ok(object)); + } + Ok(StateMachineControlFlow::Done(None)) => { + return None; + } + Err(e) => { + return Some(Err(e)); + } + } + } + } + None + } + + /// Deserialize the next object directly to `T`. + /// + /// This function goes immediately from the inner representation to `T` without going through + /// [`Value`] first. It does not support schema resolution using a reader [`Schema`]. + /// + /// # Panics + /// Will panic if a reader [`Schema`] was supplied when creating the [`Reader`]. + pub fn next_serde<'b, T: Deserialize<'b>>(&mut self) -> Option> { + assert!( + self.reader_schema.is_none(), + "Schema resolution is not supported with this function!" + ); + self.next_object() + .map(|r| r.and_then(|mut tape| deserialize_from_tape(&mut tape, &self.header.schema))) + } +} + +impl Iterator for Reader<'_, R> { + type Item = Result; + + fn next(&mut self) -> Option { + self.next_object().map(|r| { + r.and_then(|mut tape| { + value_from_tape(&mut tape, &self.header.schema, &self.header.names) + }) + .and_then(|v| { + if let Some(schema) = &self.reader_schema { + v.resolve_internal(schema, &self.header.names, &None, &None) + } else { + Ok(v) + } + }) + }) + } +} + +/// Decode a raw Avro datum using the provided [`Schema`]. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. +/// +/// **NOTE** This function is very niche and does NOT take care of reading the header and +/// consecutive data blocks. use [`Reader`] if you just want to read an Avro encoded file. +pub fn from_avro_datum( + writer_schema: &Schema, + reader: &mut R, + reader_schema: Option<&Schema>, +) -> AvroResult { + from_avro_datum_reader_schemata(writer_schema, Vec::new(), reader, reader_schema, Vec::new()) +} + +/// Decode a raw Avro datum using the provided [`Schema`] and schemata. +/// +/// `schema` must be in `schemata`. 
Otherwise, any self-referential [`Schema::Ref`]s can't be +/// resolved and an error will be returned. +/// +/// If the writer schema contains any [`Schema::Ref`] then it will use the provided +/// schemata to resolve any dependencies. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. +pub fn from_avro_datum_schemata( + writer_schema: &Schema, + writer_schemata: Vec<&Schema>, + reader: &mut R, + reader_schema: Option<&Schema>, +) -> AvroResult { + from_avro_datum_reader_schemata( + writer_schema, + writer_schemata, + reader, + reader_schema, + Vec::new(), + ) +} + +/// Decode a raw Avro datum using the provided [`Schema`] and schemata. +/// +/// `schema` must be in `schemata`. Otherwise, any self-referential [`Schema::Ref`]s can't be +/// resolved and an error will be returned. +/// +/// If the writer schema contains any [`Schema::Ref`] then it will use the provided +/// schemata to resolve any dependencies. +/// +/// In case a reader [`Schema`] is provided, schema resolution will be performed. 
+// TODO: These should really be a reusable reader, as quite a lot of work is done on creation +pub fn from_avro_datum_reader_schemata( + writer_schema: &Schema, + writer_schemata: Vec<&Schema>, + reader: &mut R, + reader_schema: Option<&Schema>, + reader_schemata: Vec<&Schema>, +) -> AvroResult { + let mut names = HashMap::new(); + if writer_schemata.is_empty() { + resolve_names(writer_schema, &mut names, &None)?; + } else { + resolve_names_with_schemata(&writer_schemata, &mut names, &None)?; + } + + let tape = CommandTape::build_from_schema(writer_schema, &names)?; + + // Read a maximum of 2Kb per read + let mut buffer = Buffer::with_capacity(2 * 1024); + let mut fsm = DatumStateMachine::new(tape); + let value = loop { + // Fill the buffer + let n = reader.read(buffer.space()).map_err(Details::ReadIntoBuf)?; + if n == 0 { + // If the writer schema is null, this is expected and we just return a null value + if matches!(writer_schema, &Schema::Null) { + break Value::Null; + } + return Err(Details::ReadIntoBuf(std::io::ErrorKind::UnexpectedEof.into()).into()); + } + buffer.fill(n); + + match fsm.parse(&mut buffer)? 
{ + StateMachineControlFlow::NeedMore(new_fsm) => { + fsm = new_fsm; + } + StateMachineControlFlow::Done(mut tape) => { + break value_from_tape(&mut tape, writer_schema, &names)?; + } + } + }; + match reader_schema { + Some(schema) => { + if reader_schemata.is_empty() { + value.resolve(schema) + } else { + value.resolve_schemata(schema, reader_schemata) + } + } + None => Ok(value), + } +} + +#[cfg(test)] +mod tests { + use crate::{Schema, Writer, state_machines::reading::sync::Reader, types::Value}; + use std::io::Cursor; + + /// Test it reads all the sync markers + #[test] + fn sync_markers() { + let mut writer = Writer::new(&Schema::String, Vec::new()); + writer.append(Value::String("Hello".to_string())).unwrap(); + writer.flush().unwrap(); + writer.append(Value::String("World".to_string())).unwrap(); + writer.flush().unwrap(); + let mut written = Cursor::new(writer.into_inner().unwrap()); + + let mut reader = Reader::new(&mut written).unwrap(); + assert_eq!( + reader.next().unwrap().unwrap(), + Value::String("Hello".to_string()) + ); + assert_eq!( + reader.next().unwrap().unwrap(), + Value::String("World".to_string()) + ); + + drop(reader); + let position = written.position(); + let expected = written.into_inner().len(); + assert_eq!(position, expected as u64); + } +} diff --git a/avro/src/state_machines/reading/union.rs b/avro/src/state_machines/reading/union.rs new file mode 100644 index 00000000..ac5c5386 --- /dev/null +++ b/avro/src/state_machines/reading/union.rs @@ -0,0 +1,78 @@ +use crate::{ + error::Details, + state_machines::reading::{ + ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, SubStateMachine, + commands::CommandTape, decode_zigzag_buffer, + }, +}; +use oval::Buffer; + +enum VariantsOrFsm { + Variants { + variants: CommandTape, + read: Vec, + }, + Fsm(Box), +} + +pub struct UnionStateMachine { + variants_or_fsm: VariantsOrFsm, + num_variants: usize, +} + +impl UnionStateMachine { + pub fn new_with_tape(variants: CommandTape, 
num_variants: usize, read: Vec) -> Self { + Self { + variants_or_fsm: VariantsOrFsm::Variants { variants, read }, + num_variants, + } + } +} + +impl StateMachine for UnionStateMachine { + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + match self.variants_or_fsm { + VariantsOrFsm::Variants { + mut variants, + mut read, + } => { + let Some(index) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + self.variants_or_fsm = VariantsOrFsm::Variants { variants, read }; + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + let option = + usize::try_from(index).map_err(|e| Details::ConvertI64ToUsize(e, index))?; + + variants.skip(option).ok_or(Details::GetUnionVariant { + index, + num_variants: self.num_variants, + })?; + + let variant = variants.command().ok_or(Details::GetUnionVariant { + index, + num_variants: self.num_variants, + })?; + + read.push(ItemRead::Union(u32::try_from(option).unwrap())); + + match variant.into_state_machine(read).parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + self.variants_or_fsm = VariantsOrFsm::Fsm(Box::new(fsm)); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + } + } + VariantsOrFsm::Fsm(fsm) => match fsm.parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.variants_or_fsm = VariantsOrFsm::Fsm(Box::new(fsm)); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + }, + } + } +} diff --git a/avro/src/types.rs b/avro/src/types.rs index 4448eef2..12e5909b 100644 --- a/avro/src/types.rs +++ b/avro/src/types.rs @@ -639,6 +639,7 @@ impl Value { mut self, schema: &Schema, names: &HashMap, + // TODO: These two should be Option<&T> instead of &Option enclosing_namespace: &Namespace, field_default: &Option, ) -> AvroResult { diff --git a/avro/src/util.rs b/avro/src/util.rs index 1c986f4b..8544362a 100644 --- a/avro/src/util.rs +++ b/avro/src/util.rs @@ -17,12 +17,9 @@ //! Utility functions, like configuring various global settings. -use crate::{AvroResult, error::Details, schema::Documentation}; +use crate::{AvroResult, Error, error::Details, schema::Documentation}; use serde_json::{Map, Value}; -use std::{ - io::{Read, Write}, - sync::OnceLock, -}; +use std::{io::Write, sync::OnceLock}; /// Maximum number of bytes that can be allocated when decoding /// Avro-encoded values. 
This is a protection against ill-formed @@ -74,10 +71,6 @@ impl MapHelper for Map { } } -pub(crate) fn read_long(reader: &mut R) -> AvroResult { - zag_i64(reader) -} - pub(crate) fn zig_i32(n: i32, buffer: W) -> AvroResult { zig_i64(n as i64, buffer) } @@ -86,20 +79,6 @@ pub(crate) fn zig_i64(n: i64, writer: W) -> AvroResult { encode_variable(((n << 1) ^ (n >> 63)) as u64, writer) } -pub(crate) fn zag_i32(reader: &mut R) -> AvroResult { - let i = zag_i64(reader)?; - i32::try_from(i).map_err(|e| Details::ZagI32(e, i).into()) -} - -pub(crate) fn zag_i64(reader: &mut R) -> AvroResult { - let z = decode_variable(reader)?; - Ok(if z & 0x1 == 0 { - (z >> 1) as i64 - } else { - !(z >> 1) as i64 - }) -} - fn encode_variable(mut z: u64, mut writer: W) -> AvroResult { let mut buffer = [0u8; 10]; let mut i: usize = 0; @@ -119,28 +98,66 @@ fn encode_variable(mut z: u64, mut writer: W) -> AvroResult { .map_err(|e| Details::WriteBytes(e).into()) } -fn decode_variable(reader: &mut R) -> AvroResult { - let mut i = 0u64; - let mut buf = [0u8; 1]; +/// Decode a zigzag encoded length. +/// +/// This version of [`decode_len`] will return a [`Details::ReadVariableIntegerBytes`] error if there are not +/// enough bytes and does not return the amount of bytes read. +/// +/// See [`decode_len`] for more details. +pub fn decode_len_simple(buffer: &[u8]) -> AvroResult<(usize, usize)> { + decode_len(buffer)?.ok_or_else(|| { + Details::ReadVariableIntegerBytes(std::io::ErrorKind::UnexpectedEof.into()).into() + }) +} - let mut j = 0; - loop { - if j > 9 { - // if j * 7 > 64 - return Err(Details::IntegerOverflow.into()); - } - reader - .read_exact(&mut buf[..]) - .map_err(Details::ReadVariableIntegerBytes)?; - i |= (u64::from(buf[0] & 0x7F)) << (j * 7); - if (buf[0] >> 7) == 0 { +/// Decode a zigzag encoded length. +/// +/// This will use [`safe_len`] to check if the length is in allowed bounds. 
+/// 
+/// # Returns
+/// `Some(integer, bytes read)` if it completely read an integer, `None` if it did not have enough
+/// bytes in the slice.
+pub fn decode_len(buffer: &[u8]) -> AvroResult> {
+ if let Some((integer, bytes)) = decode_variable(buffer)? {
+ let length =
+ usize::try_from(integer).map_err(|e| Details::ConvertI64ToUsize(e, integer))?;
+ let safe = safe_len(length)?;
+ Ok(Some((safe, bytes)))
+ } else {
+ Ok(None)
+ }
+}
+
+/// Decode a zigzag encoded integer.
+///
+/// # Returns
+/// `Some(integer, bytes read)` if it completely read an integer, `None` if it did not have enough
+/// bytes in the slice.
+pub fn decode_variable(buffer: &[u8]) -> Result, Error> {
+ let mut decoded = 0;
+ let mut loops_done = 0;
+ let mut last_byte = 0x80; // continuation bit set, so an empty buffer reports Ok(None) below
+
+ for (counter, &byte) in buffer.iter().take(10).enumerate() {
+ decoded |= u64::from(byte & 0x7F) << (counter * 7);
+ loops_done = counter;
+ last_byte = byte;
+ if byte >> 7 == 0 {
 break;
- } else {
- j += 1;
 }
 }
- Ok(i)
+ if last_byte >> 7 != 0 {
+ if loops_done == 9 { // the 10th byte still has the continuation bit: varint > 64 bits
+ Err(Details::IntegerOverflow.into())
+ } else {
+ Ok(None)
+ }
+ } else if decoded & 0x1 == 0 {
+ Ok(Some(((decoded >> 1) as i64, loops_done + 1)))
+ } else {
+ Ok(Some((!(decoded >> 1) as i64, loops_done + 1)))
+ }
 }
 /// Set the maximum number of bytes that can be allocated when decoding data.
@@ -282,8 +299,8 @@ mod tests {
 #[test]
 fn test_overflow() {
- let causes_left_shift_overflow: &[u8] = &[0xe1, 0xe1, 0xe1, 0xe1, 0xe1];
- assert!(decode_variable(&mut &*causes_left_shift_overflow).is_err());
+ let not_enough_bytes: &[u8] = &[0xe1, 0xe1, 0xe1, 0xe1, 0xe1];
+ assert!(decode_variable(not_enough_bytes).unwrap().is_none());
 }
 #[test]
diff --git a/avro_derive/tests/derive.proptest-regressions b/avro_derive/tests/derive.proptest-regressions
new file mode 100644
index 00000000..093d3789
--- /dev/null
+++ b/avro_derive/tests/derive.proptest-regressions
@@ -0,0 +1,12 @@
+# Seeds for failure cases proptest has generated in the past. 
It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 7808c1b336ad8808516f0338a69eb961f7b70c7700e6d00f528b86c9ec48b9e7 # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = -4611686018427387905, h = 0.0, i = 0.0, j = "" +cc f9d4f3f6442d2c9a718a23a9450e94ec2cef74ac62776b068581376082f0aace # shrinks to a = 0, b = "" +cc 7bc3220d3d8a41cdbf5ab3618aeb4c8dd8268b524c1feb510d4ad39d104b261a # shrinks to a = "", b = [], c = {} +cc 2d98b332f36a28eba7e4a8d5e97856a4ba7ea54defb427285a70129cb4782b1a # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = 0, h = 0.0, i = -0.0, j = "" +cc 5b163ba13d686d86e78c8139e88278720019789f6841e58ead62807244f7582a # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = 0, h = 0.0, i = 0.0, j = "" +cc aa33e374dc52ac66d49a7bad56cb62eeb80c60f42c9dc8e831db1f581d4a2b07 # shrinks to a = false, b = 0, c = 0, d = 0, e = 0, f = 0, g = 0, h = 0.0, i = 0.0, j = "", aa = 0 diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..9fabc05e --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,2 @@ +imports_granularity="Crate" + From 029a7880e200ab87fad344ed38f91cce94930050 Mon Sep 17 00:00:00 2001 From: Kriskras99 Date: Tue, 28 Oct 2025 20:49:14 +0100 Subject: [PATCH 6/7] Reorganise `decode` and `reader` --- Cargo.lock | 20 ++ avro/Cargo.toml | 1 + .../reading => decode}/block.rs | 4 +- .../reading => decode}/bytes.rs | 2 +- .../reading => decode}/codec.rs | 12 +- .../reading => decode}/commands.rs | 8 +- .../reading => decode}/datum.rs | 6 +- .../reading => decode}/error.rs | 2 +- .../{state_machines/reading => decode}/mod.rs | 20 +- .../object_container_file.rs | 10 +- .../reading => decode}/union.rs | 4 +- avro/src/encode2/mod.rs | 1 + avro/src/error.rs | 2 +- avro/src/headers.rs | 14 + avro/src/lib.rs | 9 +- .../async_impl.rs => 
reader/asynch.rs} | 6 +- avro/src/{reader.rs => reader/mod.rs} | 339 ++---------------- .../reading => reader}/sync.rs | 314 +++++++++++++++- avro/src/state_machines/mod.rs | 1 - avro/src/util.rs | 6 +- 20 files changed, 417 insertions(+), 364 deletions(-) rename avro/src/{state_machines/reading => decode}/block.rs (99%) rename avro/src/{state_machines/reading => decode}/bytes.rs (96%) rename avro/src/{state_machines/reading => decode}/codec.rs (94%) rename avro/src/{state_machines/reading => decode}/commands.rs (99%) rename avro/src/{state_machines/reading => decode}/datum.rs (97%) rename avro/src/{state_machines/reading => decode}/error.rs (90%) rename avro/src/{state_machines/reading => decode}/mod.rs (99%) rename avro/src/{state_machines/reading => decode}/object_container_file.rs (99%) rename avro/src/{state_machines/reading => decode}/union.rs (98%) create mode 100644 avro/src/encode2/mod.rs rename avro/src/{state_machines/reading/async_impl.rs => reader/asynch.rs} (99%) rename avro/src/{reader.rs => reader/mod.rs} (52%) rename avro/src/{state_machines/reading => reader}/sync.rs (53%) delete mode 100644 avro/src/state_machines/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 8da21018..be9bbebe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -55,6 +55,7 @@ dependencies = [ "bigdecimal", "bon", "bzip2", + "corosensei", "crc32fast", "criterion", "digest", @@ -336,6 +337,19 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "corosensei" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d46a43097861058cb45affe888e40ba19b57a8210650144cdc7b50c9d87840a" +dependencies = [ + "autocfg", + "cfg-if", + "libc", + "scopeguard", + "windows-sys 0.59.0", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -1101,6 +1115,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "semver" version = "1.0.27" diff --git a/avro/Cargo.toml b/avro/Cargo.toml index e22a85ad..e4c16cc0 100644 --- a/avro/Cargo.toml +++ b/avro/Cargo.toml @@ -60,6 +60,7 @@ apache-avro-derive = { default-features = false, version = "0.20.0", path = "../ bigdecimal = { default-features = false, version = "0.4.9", features = ["std", "serde"] } bon = { default-features = false, version = "3.8.1" } bzip2 = { version = "0.6.1", optional = true } +corosensei = { version = "0.3.1" } crc32fast = { default-features = false, version = "1.5.0", optional = true } digest = { default-features = false, version = "0.10.7", features = ["core-api"] } miniz_oxide = { default-features = false, version = "0.8.9", features = ["with-alloc"] } diff --git a/avro/src/state_machines/reading/block.rs b/avro/src/decode/block.rs similarity index 99% rename from avro/src/state_machines/reading/block.rs rename to avro/src/decode/block.rs index cf03f6f2..8e3028fc 100644 --- a/avro/src/state_machines/reading/block.rs +++ b/avro/src/decode/block.rs @@ -2,11 +2,11 @@ use oval::Buffer; use crate::{ Error, - error::Details, - state_machines::reading::{ + decode::{ CommandTape, ItemRead, StateMachine, StateMachineControlFlow, datum::DatumStateMachine, decode_zigzag_buffer, }, + error::Details, }; /// Are we currently parsing an object or just finished/reading a block header diff --git a/avro/src/state_machines/reading/bytes.rs b/avro/src/decode/bytes.rs similarity index 96% rename from avro/src/state_machines/reading/bytes.rs rename to avro/src/decode/bytes.rs index 68bc529c..987e12f7 100644 --- a/avro/src/state_machines/reading/bytes.rs +++ b/avro/src/decode/bytes.rs @@ -1,8 +1,8 @@ use oval::Buffer; use crate::{ + decode::{StateMachine, StateMachineControlFlow, decode_zigzag_buffer}, error::Details, - state_machines::reading::{StateMachine, StateMachineControlFlow, decode_zigzag_buffer}, }; use 
super::StateMachineResult; diff --git a/avro/src/state_machines/reading/codec.rs b/avro/src/decode/codec.rs similarity index 94% rename from avro/src/state_machines/reading/codec.rs rename to avro/src/decode/codec.rs index b823ee6c..a27ff694 100644 --- a/avro/src/state_machines/reading/codec.rs +++ b/avro/src/decode/codec.rs @@ -1,6 +1,6 @@ use crate::{ Codec, - state_machines::reading::{StateMachine, StateMachineControlFlow, StateMachineResult}, + decode::{StateMachine, StateMachineControlFlow, StateMachineResult}, }; use oval::Buffer; @@ -36,7 +36,7 @@ pub enum Decoder { #[cfg(feature = "bzip")] Bzip2(bzip2::Decompress), #[cfg(feature = "xz")] - Xz(xz2::stream::Stream), + Xz(liblzma::stream::Stream), } impl From for Decoder { @@ -54,7 +54,9 @@ impl From for Decoder { #[cfg(feature = "bzip")] Codec::Bzip2(_) => Self::Bzip2(bzip2::Decompress::new(false)), #[cfg(feature = "xz")] - Codec::Xz(_) => Self::Xz(xz2::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap()), + Codec::Xz(_) => { + Self::Xz(liblzma::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap()) + } } } } @@ -80,7 +82,7 @@ impl Decoder { // No reset/reinit API available let _drop = std::mem::replace( decoder, - xz2::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap(), + liblzma::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap(), ); } } @@ -144,7 +146,7 @@ impl StateMachine for CodecStateMachine { } #[cfg(feature = "xz")] Decoder::Xz(decoder) => { - use xz2::stream::Action::Run; + use liblzma::stream::Action::Run; let prev_total_in = decoder.total_in(); let prev_total_out = decoder.total_out(); diff --git a/avro/src/state_machines/reading/commands.rs b/avro/src/decode/commands.rs similarity index 99% rename from avro/src/state_machines/reading/commands.rs rename to avro/src/decode/commands.rs index 734d5c04..fd607057 100644 --- a/avro/src/state_machines/reading/commands.rs +++ b/avro/src/decode/commands.rs @@ -1,14 +1,14 @@ use crate::{ Error, Schema, + decode::{ + ItemRead, SubStateMachine, 
block::BlockStateMachine, bytes::BytesStateMachine, + datum::DatumStateMachine, union::UnionStateMachine, + }, error::Details, schema::{ ArraySchema, DecimalSchema, EnumSchema, FixedSchema, MapSchema, Name, Names, RecordSchema, UnionSchema, }, - state_machines::reading::{ - ItemRead, SubStateMachine, block::BlockStateMachine, bytes::BytesStateMachine, - datum::DatumStateMachine, union::UnionStateMachine, - }, }; use std::{collections::HashMap, ops::Range, sync::Arc}; diff --git a/avro/src/state_machines/reading/datum.rs b/avro/src/decode/datum.rs similarity index 97% rename from avro/src/state_machines/reading/datum.rs rename to avro/src/decode/datum.rs index 2e9a4beb..163af72c 100644 --- a/avro/src/state_machines/reading/datum.rs +++ b/avro/src/decode/datum.rs @@ -1,8 +1,8 @@ use oval::Buffer; -use super::StateMachineResult; -use crate::state_machines::reading::{ - CommandTape, ItemRead, StateMachine, StateMachineControlFlow, SubStateMachine, +use crate::decode::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, + SubStateMachine, }; enum TapeOrFsm { diff --git a/avro/src/state_machines/reading/error.rs b/avro/src/decode/error.rs similarity index 90% rename from avro/src/state_machines/reading/error.rs rename to avro/src/decode/error.rs index 12bcefd8..1266b15d 100644 --- a/avro/src/state_machines/reading/error.rs +++ b/avro/src/decode/error.rs @@ -1,4 +1,4 @@ -use crate::{Schema, state_machines::reading::ItemRead}; +use crate::{Schema, decode::ItemRead}; use thiserror::Error; #[derive(Error, Debug)] diff --git a/avro/src/state_machines/reading/mod.rs b/avro/src/decode/mod.rs similarity index 99% rename from avro/src/state_machines/reading/mod.rs rename to avro/src/decode/mod.rs index 465642ac..a2dda7fc 100644 --- a/avro/src/state_machines/reading/mod.rs +++ b/avro/src/decode/mod.rs @@ -1,15 +1,15 @@ use crate::{ Decimal, Duration, Error, Schema, bigdecimal::deserialize_big_decimal, + decode::{ + block::BlockStateMachine, 
bytes::BytesStateMachine, commands::CommandTape, + datum::DatumStateMachine, error::ValueFromTapeError, union::UnionStateMachine, + }, error::Details, schema::{ ArraySchema, EnumSchema, FixedSchema, MapSchema, Name, Names, Namespace, RecordSchema, ResolvedSchema, UnionSchema, }, - state_machines::reading::{ - block::BlockStateMachine, bytes::BytesStateMachine, commands::CommandTape, - datum::DatumStateMachine, error::ValueFromTapeError, union::UnionStateMachine, - }, types::Value, util::decode_variable, }; @@ -18,15 +18,13 @@ use serde::Deserialize; use std::{borrow::Borrow, collections::HashMap, io::Read, ops::Deref, str::FromStr}; use uuid::Uuid; -pub mod async_impl; pub mod block; pub mod bytes; pub mod codec; -mod commands; +pub mod commands; pub mod datum; pub mod error; -mod object_container_file; -pub mod sync; +pub mod object_container_file; mod union; pub trait StateMachine: Sized { @@ -43,7 +41,7 @@ pub trait StateMachine: Sized { pub enum StateMachineControlFlow { /// The state machine needs more data before it can continue. NeedMore(StateMachine), - /// The state machine is done and the result is returned.s + /// The state machine is done and the result is returned. Done(Output), } @@ -52,9 +50,9 @@ pub type StateMachineResult = /// The sub state machine that is currently being driven. /// -/// The `Int`, `Long`, `Float`, `Double`, and `Enum` statemachines don't have state, as +/// The `Int`, `Long`, `Float`, `Double`, and `Enum` state machines don't have state, as /// they don't consume the buffer if there are not enough bytes. This means that the only -/// thing these statemachines are keeping track of is which type we're actually decoding. +/// thing these state machines are keeping track of is which type we're actually decoding. 
pub enum SubStateMachine { Null(Vec), Bool(Vec), diff --git a/avro/src/state_machines/reading/object_container_file.rs b/avro/src/decode/object_container_file.rs similarity index 99% rename from avro/src/state_machines/reading/object_container_file.rs rename to avro/src/decode/object_container_file.rs index 8f9c2bd5..d1f111c1 100644 --- a/avro/src/state_machines/reading/object_container_file.rs +++ b/avro/src/decode/object_container_file.rs @@ -1,11 +1,11 @@ use crate::{ Codec, Error, Schema, - error::Details, - schema::{Names, ResolvedSchema, resolve_names, resolve_names_with_schemata}, - state_machines::reading::{ + decode::{ CommandTape, ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, codec::CodecStateMachine, datum::DatumStateMachine, decode_zigzag_buffer, }, + error::Details, + schema::{Names, ResolvedSchema, resolve_names, resolve_names_with_schemata}, }; use log::warn; use oval::Buffer; @@ -59,7 +59,7 @@ impl ObjectContainerFileHeader { /// # Panics /// Will panic if the tape was not produced from [`Self::command_tape()`]. 
pub fn from_tape(mut tape: Vec, mut schemata: Vec<&Schema>) -> Result { - // We want to read the tape from front to back + // Vec::remove(0) is an O(N) operation, so we use `drain` to read from front to back let mut tape = tape.drain(..); let mut schema = None; @@ -303,7 +303,7 @@ mod tests { use crate::{ Schema, - state_machines::reading::{ + decode::{ commands::CommandTape, object_container_file::{HEADER_JSON, HEADER_TAPE}, }, diff --git a/avro/src/state_machines/reading/union.rs b/avro/src/decode/union.rs similarity index 98% rename from avro/src/state_machines/reading/union.rs rename to avro/src/decode/union.rs index ac5c5386..05d11662 100644 --- a/avro/src/state_machines/reading/union.rs +++ b/avro/src/decode/union.rs @@ -1,9 +1,9 @@ use crate::{ - error::Details, - state_machines::reading::{ + decode::{ ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, SubStateMachine, commands::CommandTape, decode_zigzag_buffer, }, + error::Details, }; use oval::Buffer; diff --git a/avro/src/encode2/mod.rs b/avro/src/encode2/mod.rs new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/avro/src/encode2/mod.rs @@ -0,0 +1 @@ + diff --git a/avro/src/error.rs b/avro/src/error.rs index 2163d050..77c76538 100644 --- a/avro/src/error.rs +++ b/avro/src/error.rs @@ -16,8 +16,8 @@ // under the License. use crate::{ + decode::error::ValueFromTapeError, schema::{Name, Schema, SchemaKind, UnionSchema}, - state_machines::reading::error::ValueFromTapeError, types::{Value, ValueKind}, }; use std::{error::Error as _, fmt}; diff --git a/avro/src/headers.rs b/avro/src/headers.rs index dce134f1..b09f4634 100644 --- a/avro/src/headers.rs +++ b/avro/src/headers.rs @@ -110,6 +110,20 @@ impl HeaderBuilder for GlueSchemaUuidHeader { } } +/// No header is used in the file +/// +/// This is useful if you want to use `from_avro_datum` without resolving the +/// schema every call, instead you can use this with [`GenericSingleObjectReader::new_with_header_builder`]. 
+/// +/// [`GenericSingleObjectReader::new_with_header_builder`]: crate::reader::GenericSingleObjectReader::new_with_header_builder +pub struct NoHeader; + +impl HeaderBuilder for NoHeader { + fn build_header(&self) -> Vec { + Vec::new() + } +} + #[cfg(test)] mod test { use super::*; diff --git a/avro/src/lib.rs b/avro/src/lib.rs index baea6e10..35ea55d1 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -949,21 +949,22 @@ mod de; mod decimal; mod duration; mod encode; -mod reader; mod ser; mod ser_schema; -mod writer; +pub mod decode; +pub mod encode2; pub mod error; pub mod headers; pub mod rabin; +pub mod reader; pub mod schema; pub mod schema_compatibility; pub mod schema_equality; -pub mod state_machines; pub mod types; pub mod util; pub mod validator; +pub mod writer; pub use crate::{ bigdecimal::BigDecimal, @@ -1041,7 +1042,7 @@ pub fn set_serde_human_readable(human_readable: bool) -> bool { /// Async versions of the types and functions. pub mod not_sync { #[doc(inline)] - pub use crate::reader::async_reader::*; + pub use crate::reader::asynch::*; } #[cfg(test)] diff --git a/avro/src/state_machines/reading/async_impl.rs b/avro/src/reader/asynch.rs similarity index 99% rename from avro/src/state_machines/reading/async_impl.rs rename to avro/src/reader/asynch.rs index c4000f1f..4a5149d9 100644 --- a/avro/src/state_machines/reading/async_impl.rs +++ b/avro/src/reader/asynch.rs @@ -6,9 +6,7 @@ use std::collections::HashMap; use crate::{ AvroResult, Error, Schema, - error::Details, - schema::{resolve_names, resolve_names_with_schemata}, - state_machines::reading::{ + decode::{ ItemRead, StateMachine, StateMachineControlFlow, commands::CommandTape, datum::DatumStateMachine, @@ -19,6 +17,8 @@ use crate::{ }, value_from_tape, }, + error::Details, + schema::{resolve_names, resolve_names_with_schemata}, types::Value, }; diff --git a/avro/src/reader.rs b/avro/src/reader/mod.rs similarity index 52% rename from avro/src/reader.rs rename to avro/src/reader/mod.rs index 
5590ddd3..0ac75c5c 100644 --- a/avro/src/reader.rs +++ b/avro/src/reader/mod.rs @@ -1,42 +1,33 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. +pub mod asynch; +pub mod sync; -//! Logic handling reading from Avro format at user level. - -pub use crate::state_machines::reading::sync::{ - Reader, from_avro_datum, from_avro_datum_reader_schemata, from_avro_datum_schemata, -}; use crate::{ - AvroResult, + AvroResult, AvroSchema, Schema, error::Details, from_value, headers::{HeaderBuilder, RabinFingerprintHeader}, - schema::{AvroSchema, ResolvedOwnedSchema, Schema}, + schema::ResolvedOwnedSchema, types::Value, }; use futures::AsyncRead; use serde::de::DeserializeOwned; use std::{io::Read, marker::PhantomData}; +pub use sync::*; -pub mod async_reader { - #[doc(inline)] - pub use crate::state_machines::reading::async_impl::{ - Reader, from_avro_datum, from_avro_datum_reader_schemata, from_avro_datum_schemata, - }; +// This is for API compatibility with previous versions +pub use sync::*; + +/// Reads the marker bytes from Avro bytes generated earlier by a [`Writer`]. 
+/// +/// [`Writer`]: crate::Writer +pub fn read_marker(bytes: &[u8]) -> [u8; 16] { + assert!( + bytes.len() > 16, + "The bytes are too short to read a marker from them" + ); + let mut marker = [0_u8; 16]; + marker.clone_from_slice(&bytes[(bytes.len() - 16)..]); + marker } /// Reader for Avro objects created using the [single-object encoding]. @@ -98,8 +89,7 @@ impl GenericSingleObjectReader { match reader.read_exact(&mut header).await { Ok(_) => { if self.expected_header == header { - async_reader::from_avro_datum(self.write_schema.get_root_schema(), reader, None) - .await + asynch::from_avro_datum(self.write_schema.get_root_schema(), reader, None).await } else { Err( Details::SingleObjectHeaderMismatch(self.expected_header.clone(), header) @@ -168,269 +158,17 @@ where } } -/// Reads the marker bytes from Avro bytes generated earlier by a [`Writer`]. -/// -/// [`Writer`]: crate::Writer -pub fn read_marker(bytes: &[u8]) -> [u8; 16] { - assert!( - bytes.len() > 16, - "The bytes are too short to read a marker from them" - ); - let mut marker = [0_u8; 16]; - marker.clone_from_slice(&bytes[(bytes.len() - 16)..]); - marker -} - #[cfg(test)] mod tests { use super::*; use crate::{ - Error, encode::encode, headers::GlueSchemaUuidHeader, rabin::Rabin, types::Record, + AvroSchema, Error, Schema, encode::encode, error::Details, headers::GlueSchemaUuidHeader, + rabin::Rabin, types::Value, }; use apache_avro_test_helper::TestResult; - use pretty_assertions::assert_eq; use serde::Deserialize; - use std::{collections::HashMap, io::Cursor}; use uuid::Uuid; - const SCHEMA: &str = r#" - { - "type": "record", - "name": "test", - "fields": [ - { - "name": "a", - "type": "long", - "default": 42 - }, - { - "name": "b", - "type": "string" - } - ] - } - "#; - const UNION_SCHEMA: &str = r#"["null", "long"]"#; - const ENCODED: &[u8] = &[ - 79u8, 98u8, 106u8, 1u8, 4u8, 22u8, 97u8, 118u8, 114u8, 111u8, 46u8, 115u8, 99u8, 104u8, - 101u8, 109u8, 97u8, 222u8, 1u8, 123u8, 34u8, 116u8, 121u8, 
112u8, 101u8, 34u8, 58u8, 34u8, - 114u8, 101u8, 99u8, 111u8, 114u8, 100u8, 34u8, 44u8, 34u8, 110u8, 97u8, 109u8, 101u8, 34u8, - 58u8, 34u8, 116u8, 101u8, 115u8, 116u8, 34u8, 44u8, 34u8, 102u8, 105u8, 101u8, 108u8, - 100u8, 115u8, 34u8, 58u8, 91u8, 123u8, 34u8, 110u8, 97u8, 109u8, 101u8, 34u8, 58u8, 34u8, - 97u8, 34u8, 44u8, 34u8, 116u8, 121u8, 112u8, 101u8, 34u8, 58u8, 34u8, 108u8, 111u8, 110u8, - 103u8, 34u8, 44u8, 34u8, 100u8, 101u8, 102u8, 97u8, 117u8, 108u8, 116u8, 34u8, 58u8, 52u8, - 50u8, 125u8, 44u8, 123u8, 34u8, 110u8, 97u8, 109u8, 101u8, 34u8, 58u8, 34u8, 98u8, 34u8, - 44u8, 34u8, 116u8, 121u8, 112u8, 101u8, 34u8, 58u8, 34u8, 115u8, 116u8, 114u8, 105u8, - 110u8, 103u8, 34u8, 125u8, 93u8, 125u8, 20u8, 97u8, 118u8, 114u8, 111u8, 46u8, 99u8, 111u8, - 100u8, 101u8, 99u8, 8u8, 110u8, 117u8, 108u8, 108u8, 0u8, 94u8, 61u8, 54u8, 221u8, 190u8, - 207u8, 108u8, 180u8, 158u8, 57u8, 114u8, 40u8, 173u8, 199u8, 228u8, 239u8, 4u8, 20u8, 54u8, - 6u8, 102u8, 111u8, 111u8, 84u8, 6u8, 98u8, 97u8, 114u8, 94u8, 61u8, 54u8, 221u8, 190u8, - 207u8, 108u8, 180u8, 158u8, 57u8, 114u8, 40u8, 173u8, 199u8, 228u8, 239u8, - ]; - - #[test] - fn test_from_avro_datum() -> TestResult { - let schema = Schema::parse_str(SCHEMA)?; - let mut encoded: &'static [u8] = &[54, 6, 102, 111, 111]; - - let mut record = Record::new(&schema).unwrap(); - record.put("a", 27i64); - record.put("b", "foo"); - let expected = record.into(); - - assert_eq!(from_avro_datum(&schema, &mut encoded, None)?, expected); - - Ok(()) - } - - #[test] - fn test_from_avro_datum_with_union_to_struct() -> TestResult { - const TEST_RECORD_SCHEMA_3240: &str = r#" - { - "type": "record", - "name": "test", - "fields": [ - { - "name": "a", - "type": "long", - "default": 42 - }, - { - "name": "b", - "type": "string" - }, - { - "name": "a_nullable_array", - "type": ["null", {"type": "array", "items": {"type": "string"}}], - "default": null - }, - { - "name": "a_nullable_boolean", - "type": ["null", {"type": "boolean"}], - "default": 
null - }, - { - "name": "a_nullable_string", - "type": ["null", {"type": "string"}], - "default": null - } - ] - } - "#; - #[derive(Default, Debug, Deserialize, PartialEq, Eq)] - struct TestRecord3240 { - a: i64, - b: String, - a_nullable_array: Option>, - // we are missing the 'a_nullable_boolean' field to simulate missing keys - // a_nullable_boolean: Option, - a_nullable_string: Option, - } - - let schema = Schema::parse_str(TEST_RECORD_SCHEMA_3240)?; - let mut encoded: &'static [u8] = &[54, 6, 102, 111, 111]; - - // The schema used to read is not compatible with what is written - assert!(from_avro_datum(&schema, &mut encoded, None).is_err()); - - // let avro_datum = from_avro_datum(&schema, &mut encoded, None)?; - - // let expected_record: TestRecord3240 = TestRecord3240 { - // a: 27i64, - // b: String::from("foo"), - // a_nullable_array: None, - // a_nullable_string: None, - // }; - - // let parsed_record: TestRecord3240 = match &avro_datum { - // Value::Record(_) => from_value::(&avro_datum)?, - // unexpected => { - // panic!("could not map avro data to struct, found unexpected: {unexpected:?}") - // } - // }; - // - // assert_eq!(parsed_record, expected_record); - - Ok(()) - } - - #[test] - fn test_null_union() -> TestResult { - let schema = Schema::parse_str(UNION_SCHEMA)?; - let mut encoded: &'static [u8] = &[2, 0]; - - assert_eq!( - from_avro_datum(&schema, &mut encoded, None)?, - Value::Union(1, Box::new(Value::Long(0))) - ); - - Ok(()) - } - - #[test] - fn test_reader_iterator() -> TestResult { - let schema = Schema::parse_str(SCHEMA)?; - let reader = Reader::with_schema(&schema, ENCODED)?; - - let mut record1 = Record::new(&schema).unwrap(); - record1.put("a", 27i64); - record1.put("b", "foo"); - - let mut record2 = Record::new(&schema).unwrap(); - record2.put("a", 42i64); - record2.put("b", "bar"); - - let expected = [record1.into(), record2.into()]; - - for (i, value) in reader.enumerate() { - assert_eq!(value?, expected[i]); - } - - Ok(()) - } - - 
#[test] - fn test_reader_invalid_header() -> TestResult { - let schema = Schema::parse_str(SCHEMA)?; - let invalid = ENCODED.iter().copied().skip(1).collect::>(); - assert!(Reader::with_schema(&schema, &invalid[..]).is_err()); - - Ok(()) - } - - #[test] - fn test_reader_invalid_block() -> TestResult { - let schema = Schema::parse_str(SCHEMA)?; - let invalid = ENCODED - .iter() - .copied() - .rev() - .skip(19) - .collect::>() - .into_iter() - .rev() - .collect::>(); - let mut reader = Reader::with_schema(&schema, &invalid[..])?; - - // The block says it contains 2 values, but only contains one. - // The first value is successfully decoded - let _v = reader.next().unwrap().unwrap(); - // The second fails with an unexpected end of file error. - assert!(reader.next().unwrap().is_err()); - - Ok(()) - } - - #[test] - fn test_reader_empty_buffer() -> TestResult { - let empty = Cursor::new(Vec::new()); - assert!(Reader::new(empty).is_err()); - - Ok(()) - } - - #[test] - fn test_reader_only_header() -> TestResult { - let invalid = ENCODED.iter().copied().take(165).collect::>(); - let reader = Reader::new(&invalid[..])?; - for value in reader { - assert!(value.is_err()); - } - - Ok(()) - } - - #[test] - fn test_avro_3405_read_user_metadata_success() -> TestResult { - use crate::writer::Writer; - - let schema = Schema::parse_str(SCHEMA)?; - let mut writer = Writer::new(&schema, Vec::new())?; - - let mut user_meta_data: HashMap> = HashMap::new(); - user_meta_data.insert("stringKey".to_string(), b"stringValue".to_vec()); - user_meta_data.insert("bytesKey".to_string(), b"bytesValue".to_vec()); - user_meta_data.insert("vecKey".to_string(), vec![1, 2, 3]); - - for (k, v) in user_meta_data.iter() { - writer.add_user_metadata(k.to_string(), v)?; - } - - let mut record = Record::new(&schema).unwrap(); - record.put("a", 27i64); - record.put("b", "foo"); - - writer.append(record.clone())?; - writer.append(record.clone())?; - writer.flush()?; - let result = writer.into_inner()?; - - let 
reader = Reader::new(&result[..])?; - assert_eq!(reader.user_metadata(), &user_meta_data); - - Ok(()) - } - #[derive(Deserialize, Clone, PartialEq, Debug)] struct TestSingleObjectReader { a: i64, @@ -538,7 +276,7 @@ mod tests { .read_value(&mut to_read) .expect("Should read"); let expected_value: Value = obj.into(); - assert_eq!(expected_value, val); + pretty_assertions::assert_eq!(expected_value, val); Ok(()) } @@ -572,7 +310,7 @@ mod tests { .read_value(&mut to_read) .expect("Should read"); let expected_value: Value = obj.into(); - assert_eq!(expected_value, val); + pretty_assertions::assert_eq!(expected_value, val); Ok(()) } @@ -616,9 +354,9 @@ mod tests { .read(&mut to_read3) .expect("Should read from deserilize"); let expected_value: Value = obj.clone().into(); - assert_eq!(obj, read_obj1); - assert_eq!(obj, read_obj2); - assert_eq!(val, expected_value); + pretty_assertions::assert_eq!(obj, read_obj1); + pretty_assertions::assert_eq!(obj, read_obj2); + pretty_assertions::assert_eq!(val, expected_value); Ok(()) } @@ -642,27 +380,4 @@ mod tests { matches!(read_result, Err(Details::ReadBytes(_))); Ok(()) } - - #[cfg(not(feature = "snappy"))] - #[test] - fn test_avro_3549_read_not_enabled_codec() { - let snappy_compressed_avro = vec![ - 79, 98, 106, 1, 4, 22, 97, 118, 114, 111, 46, 115, 99, 104, 101, 109, 97, 210, 1, 123, - 34, 102, 105, 101, 108, 100, 115, 34, 58, 91, 123, 34, 110, 97, 109, 101, 34, 58, 34, - 110, 117, 109, 34, 44, 34, 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 105, 110, - 103, 34, 125, 93, 44, 34, 110, 97, 109, 101, 34, 58, 34, 101, 118, 101, 110, 116, 34, - 44, 34, 110, 97, 109, 101, 115, 112, 97, 99, 101, 34, 58, 34, 101, 120, 97, 109, 112, - 108, 101, 110, 97, 109, 101, 115, 112, 97, 99, 101, 34, 44, 34, 116, 121, 112, 101, 34, - 58, 34, 114, 101, 99, 111, 114, 100, 34, 125, 20, 97, 118, 114, 111, 46, 99, 111, 100, - 101, 99, 12, 115, 110, 97, 112, 112, 121, 0, 213, 209, 241, 208, 200, 110, 164, 47, - 203, 25, 90, 235, 161, 167, 195, 177, 
2, 20, 4, 12, 6, 49, 50, 51, 115, 38, 58, 0, 213, - 209, 241, 208, 200, 110, 164, 47, 203, 25, 90, 235, 161, 167, 195, 177, - ]; - - if let Err(err) = Reader::new(snappy_compressed_avro.as_slice()) { - assert_eq!("Codec 'snappy' is not supported/enabled", err.to_string()); - } else { - panic!("Expected an error in the reading of the codec!"); - } - } } diff --git a/avro/src/state_machines/reading/sync.rs b/avro/src/reader/sync.rs similarity index 53% rename from avro/src/state_machines/reading/sync.rs rename to avro/src/reader/sync.rs index daab82a0..c0f44648 100644 --- a/avro/src/state_machines/reading/sync.rs +++ b/avro/src/reader/sync.rs @@ -4,9 +4,7 @@ use std::{collections::HashMap, io::Read}; use crate::{ AvroResult, Error, Schema, - error::Details, - schema::{resolve_names, resolve_names_with_schemata}, - state_machines::reading::{ + decode::{ ItemRead, StateMachine, StateMachineControlFlow, commands::CommandTape, datum::DatumStateMachine, @@ -17,6 +15,8 @@ use crate::{ }, value_from_tape, }, + error::Details, + schema::{resolve_names, resolve_names_with_schemata}, types::Value, }; @@ -282,6 +282,8 @@ pub fn from_avro_datum_reader_schemata( let tape = CommandTape::build_from_schema(writer_schema, &names)?; + println!("{tape:#?}"); + // Read a maximum of 2Kb per read let mut buffer = Buffer::with_capacity(2 * 1024); let mut fsm = DatumStateMachine::new(tape); @@ -320,13 +322,20 @@ pub fn from_avro_datum_reader_schemata( #[cfg(test)] mod tests { - use crate::{Schema, Writer, state_machines::reading::sync::Reader, types::Value}; - use std::io::Cursor; + use super::*; + use crate::{ + Schema, Writer, from_value, + types::{Record, Value}, + }; + use apache_avro_test_helper::TestResult; + use pretty_assertions::assert_eq; + use serde::Deserialize; + use std::{collections::HashMap, io::Cursor}; /// Test it reads all the sync markers #[test] fn sync_markers() { - let mut writer = Writer::new(&Schema::String, Vec::new()); + let mut writer = 
Writer::new(&Schema::String, Vec::new()).unwrap(); writer.append(Value::String("Hello".to_string())).unwrap(); writer.flush().unwrap(); writer.append(Value::String("World".to_string())).unwrap(); @@ -348,4 +357,297 @@ mod tests { let expected = written.into_inner().len(); assert_eq!(position, expected as u64); } + + const SCHEMA: &str = r#" + { + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", + "type": "long", + "default": 42 + }, + { + "name": "b", + "type": "string" + } + ] + } + "#; + const UNION_SCHEMA: &str = r#"["null", "long"]"#; + const ENCODED: &[u8] = &[ + 79u8, 98u8, 106u8, 1u8, 4u8, 22u8, 97u8, 118u8, 114u8, 111u8, 46u8, 115u8, 99u8, 104u8, + 101u8, 109u8, 97u8, 222u8, 1u8, 123u8, 34u8, 116u8, 121u8, 112u8, 101u8, 34u8, 58u8, 34u8, + 114u8, 101u8, 99u8, 111u8, 114u8, 100u8, 34u8, 44u8, 34u8, 110u8, 97u8, 109u8, 101u8, 34u8, + 58u8, 34u8, 116u8, 101u8, 115u8, 116u8, 34u8, 44u8, 34u8, 102u8, 105u8, 101u8, 108u8, + 100u8, 115u8, 34u8, 58u8, 91u8, 123u8, 34u8, 110u8, 97u8, 109u8, 101u8, 34u8, 58u8, 34u8, + 97u8, 34u8, 44u8, 34u8, 116u8, 121u8, 112u8, 101u8, 34u8, 58u8, 34u8, 108u8, 111u8, 110u8, + 103u8, 34u8, 44u8, 34u8, 100u8, 101u8, 102u8, 97u8, 117u8, 108u8, 116u8, 34u8, 58u8, 52u8, + 50u8, 125u8, 44u8, 123u8, 34u8, 110u8, 97u8, 109u8, 101u8, 34u8, 58u8, 34u8, 98u8, 34u8, + 44u8, 34u8, 116u8, 121u8, 112u8, 101u8, 34u8, 58u8, 34u8, 115u8, 116u8, 114u8, 105u8, + 110u8, 103u8, 34u8, 125u8, 93u8, 125u8, 20u8, 97u8, 118u8, 114u8, 111u8, 46u8, 99u8, 111u8, + 100u8, 101u8, 99u8, 8u8, 110u8, 117u8, 108u8, 108u8, 0u8, 94u8, 61u8, 54u8, 221u8, 190u8, + 207u8, 108u8, 180u8, 158u8, 57u8, 114u8, 40u8, 173u8, 199u8, 228u8, 239u8, 4u8, 20u8, 54u8, + 6u8, 102u8, 111u8, 111u8, 84u8, 6u8, 98u8, 97u8, 114u8, 94u8, 61u8, 54u8, 221u8, 190u8, + 207u8, 108u8, 180u8, 158u8, 57u8, 114u8, 40u8, 173u8, 199u8, 228u8, 239u8, + ]; + + #[test] + fn test_from_avro_datum() -> TestResult { + let schema = Schema::parse_str(SCHEMA)?; + let mut encoded: &'static [u8] 
= &[54, 6, 102, 111, 111]; + + let mut record = Record::new(&schema).unwrap(); + record.put("a", 27i64); + record.put("b", "foo"); + let expected = record.into(); + + assert_eq!(from_avro_datum(&schema, &mut encoded, None)?, expected); + + Ok(()) + } + + #[test] + fn test_from_avro_datum_with_union_to_struct() -> TestResult { + // TODO: Inform @ultrabug that their fix has been reverted and they need to use from_avro_datum_reader_schemata + const TEST_RECORD_WRITER_SCHEMA_3240: &str = r#" + { + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", + "type": "long" + }, + { + "name": "b", + "type": "string" + } + ] + } + "#; + + const TEST_RECORD_READER_SCHEMA_3240: &str = r#" + { + "type": "record", + "name": "test", + "fields": [ + { + "name": "a", + "type": "long", + "default": 42 + }, + { + "name": "b", + "type": "string" + }, + { + "name": "a_nullable_array", + "type": ["null", {"type": "array", "items": {"type": "string"}}], + "default": null + }, + { + "name": "a_nullable_boolean", + "type": ["null", {"type": "boolean"}], + "default": null + }, + { + "name": "a_nullable_string", + "type": ["null", {"type": "string"}], + "default": null + } + ] + } + "#; + const TEST_DATUM_3240: &[u8] = &[54, 6, 102, 111, 111]; + + #[derive(Default, Debug, Deserialize, PartialEq, Eq)] + struct TestRecord3240 { + a: i64, + b: String, + a_nullable_array: Option>, + // we are missing the 'a_nullable_boolean' field to simulate missing keys + // a_nullable_boolean: Option, + a_nullable_string: Option, + } + + let reader_schema = Schema::parse_str(TEST_RECORD_READER_SCHEMA_3240)?; + let writer_schema = Schema::parse_str(TEST_RECORD_WRITER_SCHEMA_3240)?; + + // The schema used to read is not compatible with what is written + assert!(from_avro_datum(&reader_schema, &mut Cursor::new(TEST_DATUM_3240), None).is_err()); + // If the writer schema is used, it is compatible + assert!(from_avro_datum(&writer_schema, &mut Cursor::new(TEST_DATUM_3240), None).is_ok()); + + // For 
schema compatibility use the writer and reader schema + let avro_datum = from_avro_datum_reader_schemata( + &writer_schema, + Vec::new(), + &mut Cursor::new(TEST_DATUM_3240), + Some(&reader_schema), + Vec::new(), + )?; + + let expected_record: TestRecord3240 = TestRecord3240 { + a: 27i64, + b: String::from("foo"), + a_nullable_array: None, + a_nullable_string: None, + }; + + let parsed_record: TestRecord3240 = match &avro_datum { + Value::Record(_) => from_value::(&avro_datum)?, + unexpected => { + panic!("could not map avro data to struct, found unexpected: {unexpected:?}") + } + }; + + assert_eq!(parsed_record, expected_record); + + Ok(()) + } + + #[test] + fn test_null_union() -> TestResult { + let schema = Schema::parse_str(UNION_SCHEMA)?; + let mut encoded: &'static [u8] = &[2, 0]; + + assert_eq!( + from_avro_datum(&schema, &mut encoded, None)?, + Value::Union(1, Box::new(Value::Long(0))) + ); + + Ok(()) + } + + #[test] + fn test_reader_iterator() -> TestResult { + let schema = Schema::parse_str(SCHEMA)?; + let reader = Reader::with_schema(&schema, ENCODED)?; + + let mut record1 = Record::new(&schema).unwrap(); + record1.put("a", 27i64); + record1.put("b", "foo"); + + let mut record2 = Record::new(&schema).unwrap(); + record2.put("a", 42i64); + record2.put("b", "bar"); + + let expected = [record1.into(), record2.into()]; + + for (i, value) in reader.enumerate() { + assert_eq!(value?, expected[i]); + } + + Ok(()) + } + + #[test] + fn test_reader_invalid_header() -> TestResult { + let schema = Schema::parse_str(SCHEMA)?; + let invalid = ENCODED.iter().copied().skip(1).collect::>(); + assert!(Reader::with_schema(&schema, &invalid[..]).is_err()); + + Ok(()) + } + + #[test] + fn test_reader_invalid_block() -> TestResult { + let schema = Schema::parse_str(SCHEMA)?; + let invalid = ENCODED + .iter() + .copied() + .rev() + .skip(19) + .collect::>() + .into_iter() + .rev() + .collect::>(); + let mut reader = Reader::with_schema(&schema, &invalid[..])?; + + // The block 
says it contains 2 values, but only contains one. + // The first value is successfully decoded + let _v = reader.next().unwrap().unwrap(); + // The second fails with an unexpected end of file error. + assert!(reader.next().unwrap().is_err()); + + Ok(()) + } + + #[test] + fn test_reader_empty_buffer() -> TestResult { + let empty = Cursor::new(Vec::new()); + assert!(Reader::new(empty).is_err()); + + Ok(()) + } + + #[test] + fn test_reader_only_header() -> TestResult { + let invalid = ENCODED.iter().copied().take(165).collect::>(); + let reader = Reader::new(&invalid[..])?; + for value in reader { + assert!(value.is_err()); + } + + Ok(()) + } + + #[test] + fn test_avro_3405_read_user_metadata_success() -> TestResult { + use crate::writer::Writer; + + let schema = Schema::parse_str(SCHEMA)?; + let mut writer = Writer::new(&schema, Vec::new()).unwrap(); + + let mut user_meta_data: HashMap> = HashMap::new(); + user_meta_data.insert("stringKey".to_string(), b"stringValue".to_vec()); + user_meta_data.insert("bytesKey".to_string(), b"bytesValue".to_vec()); + user_meta_data.insert("vecKey".to_string(), vec![1, 2, 3]); + + for (k, v) in user_meta_data.iter() { + writer.add_user_metadata(k.to_string(), v)?; + } + + let mut record = Record::new(&schema).unwrap(); + record.put("a", 27i64); + record.put("b", "foo"); + + writer.append(record.clone())?; + writer.append(record.clone())?; + writer.flush()?; + let result = writer.into_inner()?; + + let reader = Reader::new(&result[..])?; + assert_eq!(reader.user_metadata(), &user_meta_data); + + Ok(()) + } + + #[cfg(not(feature = "snappy"))] + #[test] + fn test_avro_3549_read_not_enabled_codec() { + let snappy_compressed_avro = vec![ + 79, 98, 106, 1, 4, 22, 97, 118, 114, 111, 46, 115, 99, 104, 101, 109, 97, 210, 1, 123, + 34, 102, 105, 101, 108, 100, 115, 34, 58, 91, 123, 34, 110, 97, 109, 101, 34, 58, 34, + 110, 117, 109, 34, 44, 34, 116, 121, 112, 101, 34, 58, 34, 115, 116, 114, 105, 110, + 103, 34, 125, 93, 44, 34, 110, 97, 109, 
101, 34, 58, 34, 101, 118, 101, 110, 116, 34, + 44, 34, 110, 97, 109, 101, 115, 112, 97, 99, 101, 34, 58, 34, 101, 120, 97, 109, 112, + 108, 101, 110, 97, 109, 101, 115, 112, 97, 99, 101, 34, 44, 34, 116, 121, 112, 101, 34, + 58, 34, 114, 101, 99, 111, 114, 100, 34, 125, 20, 97, 118, 114, 111, 46, 99, 111, 100, + 101, 99, 12, 115, 110, 97, 112, 112, 121, 0, 213, 209, 241, 208, 200, 110, 164, 47, + 203, 25, 90, 235, 161, 167, 195, 177, 2, 20, 4, 12, 6, 49, 50, 51, 115, 38, 58, 0, 213, + 209, 241, 208, 200, 110, 164, 47, 203, 25, 90, 235, 161, 167, 195, 177, + ]; + + if let Err(err) = Reader::new(snappy_compressed_avro.as_slice()) { + pretty_assertions::assert_eq!( + "Codec 'snappy' is not supported/enabled", + err.to_string() + ); + } else { + panic!("Expected an error in the reading of the codec!"); + } + } } diff --git a/avro/src/state_machines/mod.rs b/avro/src/state_machines/mod.rs deleted file mode 100644 index 28157eae..00000000 --- a/avro/src/state_machines/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod reading; diff --git a/avro/src/util.rs b/avro/src/util.rs index 8544362a..d0c3dbf8 100644 --- a/avro/src/util.rs +++ b/avro/src/util.rs @@ -104,7 +104,7 @@ fn encode_variable(mut z: u64, mut writer: W) -> AvroResult { /// enough bytes and does not return the amount of bytes read. /// /// See [`decode_len`] for more details. -pub fn decode_len_simple(buffer: &[u8]) -> AvroResult<(usize, usize)> { +pub(crate) fn decode_len_simple(buffer: &[u8]) -> AvroResult<(usize, usize)> { decode_len(buffer)?.ok_or_else(|| { Details::ReadVariableIntegerBytes(std::io::ErrorKind::UnexpectedEof.into()).into() }) @@ -117,7 +117,7 @@ pub fn decode_len_simple(buffer: &[u8]) -> AvroResult<(usize, usize)> { /// # Returns /// `Some(integer, bytes read)` if it completely read an integer, `None` if it did not have enough /// bytes in the slice. 
-pub fn decode_len(buffer: &[u8]) -> AvroResult> { +pub(crate) fn decode_len(buffer: &[u8]) -> AvroResult> { if let Some((integer, bytes)) = decode_variable(buffer)? { let length = usize::try_from(integer).map_err(|e| Details::ConvertI64ToUsize(e, integer))?; @@ -133,7 +133,7 @@ pub fn decode_len(buffer: &[u8]) -> AvroResult> { /// # Returns /// `Some(integer, bytes read)` if it completely read an integer, `None` if it did not have enough /// bytes in the slice. -pub fn decode_variable(buffer: &[u8]) -> Result, Error> { +pub(crate) fn decode_variable(buffer: &[u8]) -> Result, Error> { let mut decoded = 0; let mut loops_done = 0; let mut last_byte = 0; From 86173d32eab113db77f6f30533433b9dd5b5534a Mon Sep 17 00:00:00 2001 From: Kriskras99 Date: Thu, 13 Nov 2025 08:48:19 +0100 Subject: [PATCH 7/7] wip: rework decode --- avro/src/decode/block.rs | 1 - avro/src/decode/commands.rs | 5 +- avro/src/decode/mod.rs | 5 +- avro/src/decode2/codec.rs | 182 ++++++ avro/src/decode2/complex/block.rs | 258 ++++++++ avro/src/decode2/complex/mod.rs | 34 ++ avro/src/decode2/complex/record.rs | 85 +++ avro/src/decode2/complex/union.rs | 64 ++ avro/src/decode2/logical/decimal.rs | 80 +++ avro/src/decode2/logical/mod.rs | 3 + avro/src/decode2/logical/time.rs | 116 ++++ avro/src/decode2/logical/uuid.rs | 45 ++ avro/src/decode2/mod.rs | 284 +++++++++ avro/src/decode2/object_container/data.rs | 97 +++ avro/src/decode2/object_container/header.rs | 200 ++++++ avro/src/decode2/object_container/mod.rs | 24 + avro/src/decode2/primitive/bytes.rs | 161 +++++ avro/src/decode2/primitive/floats.rs | 36 ++ avro/src/decode2/primitive/mod.rs | 35 ++ avro/src/decode2/primitive/zigzag.rs | 45 ++ avro/src/encode.rs | 12 +- avro/src/encode2/block.rs | 110 ++++ avro/src/encode2/bytes.rs | 71 +++ avro/src/encode2/codec.rs | 182 ++++++ avro/src/encode2/commands.rs | 646 ++++++++++++++++++++ avro/src/encode2/datum.rs | 80 +++ avro/src/encode2/error.rs | 16 + avro/src/encode2/mod.rs | 161 +++++ 
avro/src/encode2/object_container_file.rs | 318 ++++++++++ avro/src/encode2/union.rs | 78 +++ avro/src/lib.rs | 1 + avro/src/reader/mod.rs | 3 - avro/src/schema.rs | 41 +- avro/src/schema_compatibility.rs | 8 +- avro/src/schema_equality.rs | 74 ++- avro/src/ser_schema.rs | 13 +- avro/src/types.rs | 21 +- avro/src/util.rs | 36 +- 38 files changed, 3592 insertions(+), 39 deletions(-) create mode 100644 avro/src/decode2/codec.rs create mode 100644 avro/src/decode2/complex/block.rs create mode 100644 avro/src/decode2/complex/mod.rs create mode 100644 avro/src/decode2/complex/record.rs create mode 100644 avro/src/decode2/complex/union.rs create mode 100644 avro/src/decode2/logical/decimal.rs create mode 100644 avro/src/decode2/logical/mod.rs create mode 100644 avro/src/decode2/logical/time.rs create mode 100644 avro/src/decode2/logical/uuid.rs create mode 100644 avro/src/decode2/mod.rs create mode 100644 avro/src/decode2/object_container/data.rs create mode 100644 avro/src/decode2/object_container/header.rs create mode 100644 avro/src/decode2/object_container/mod.rs create mode 100644 avro/src/decode2/primitive/bytes.rs create mode 100644 avro/src/decode2/primitive/floats.rs create mode 100644 avro/src/decode2/primitive/mod.rs create mode 100644 avro/src/decode2/primitive/zigzag.rs create mode 100644 avro/src/encode2/block.rs create mode 100644 avro/src/encode2/bytes.rs create mode 100644 avro/src/encode2/codec.rs create mode 100644 avro/src/encode2/commands.rs create mode 100644 avro/src/encode2/datum.rs create mode 100644 avro/src/encode2/error.rs create mode 100644 avro/src/encode2/object_container_file.rs create mode 100644 avro/src/encode2/union.rs diff --git a/avro/src/decode/block.rs b/avro/src/decode/block.rs index 8e3028fc..271de804 100644 --- a/avro/src/decode/block.rs +++ b/avro/src/decode/block.rs @@ -25,7 +25,6 @@ pub struct BlockStateMachine { impl BlockStateMachine { pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { Self { - // This 
clone is *cheap* command_tape, tape_or_fsm: TapeOrFsm::Tape(tape), left_in_current_block: 0, diff --git a/avro/src/decode/commands.rs b/avro/src/decode/commands.rs index fd607057..27380801 100644 --- a/avro/src/decode/commands.rs +++ b/avro/src/decode/commands.rs @@ -385,7 +385,7 @@ impl<'a> CommandTapeBuilder<'a> { self.tape.push(CommandTape::BYTES); Ok(1) } - Schema::String | Schema::Uuid => { + Schema::String | Schema::Uuid(_) => { self.tape.push(CommandTape::STRING); Ok(1) } @@ -512,6 +512,7 @@ impl<'a> CommandTapeBuilder<'a> { #[cfg(test)] mod tests { + use crate::schema::UuidSchema; use super::*; #[test] @@ -636,7 +637,7 @@ mod tests { &[CommandTape::STRING] ); assert_eq!( - CommandTape::build_from_schema(&Schema::Uuid, &HashMap::new()) + CommandTape::build_from_schema(&Schema::Uuid(UuidSchema::String), &HashMap::new()) .unwrap() .inner .as_ref(), diff --git a/avro/src/decode/mod.rs b/avro/src/decode/mod.rs index a2dda7fc..6219ec20 100644 --- a/avro/src/decode/mod.rs +++ b/avro/src/decode/mod.rs @@ -450,7 +450,7 @@ pub fn value_from_tape_internal( } .into()), }, - Schema::Uuid => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? { + Schema::Uuid(_) => match tape.next().ok_or(ValueFromTapeError::UnexpectedEndOfTape)? 
{ ItemRead::String(string) => Uuid::from_str(&string) .map(Value::Uuid) .map_err(|e| Details::ConvertStrToUuid(e).into()), @@ -605,6 +605,7 @@ mod tests { use pretty_assertions::assert_eq; use std::collections::HashMap; use uuid::Uuid; + use crate::schema::UuidSchema; #[test] fn test_decode_array_without_size() -> TestResult { @@ -1116,7 +1117,7 @@ mod tests { let mut buffer = Vec::new(); encode(&value, &schema, &mut buffer).expect(&success(&value, &schema)); - let result = from_avro_datum(&Schema::Uuid, &mut &buffer[..], None)?; + let result = from_avro_datum(&Schema::Uuid(UuidSchema::String), &mut &buffer[..], None)?; assert_eq!(result, value); Ok(()) diff --git a/avro/src/decode2/codec.rs b/avro/src/decode2/codec.rs new file mode 100644 index 00000000..b6413f8e --- /dev/null +++ b/avro/src/decode2/codec.rs @@ -0,0 +1,182 @@ +use crate::{ + Codec, + decode2::{Fsm, FsmControlFlow, FsmResult}, +}; +use oval::Buffer; + +pub struct CodecStateMachine { + sub_machine: Option, + codec: Decoder, + buffer: Buffer, +} + +impl CodecStateMachine { + pub fn new(sub_machine: T, codec: Codec) -> Self { + Self { + sub_machine: Some(sub_machine), + codec: codec.into(), + buffer: Buffer::with_capacity(1024), + } + } + + pub fn reset(&mut self, sub_machine: T) { + self.buffer.reset(); + self.sub_machine = Some(sub_machine); + self.codec.reset(); + } +} + +pub enum Decoder { + Null, + Deflate(Box), + #[cfg(feature = "snappy")] + Snappy(snap::raw::Decoder), + #[cfg(feature = "zstandard")] + Zstandard(zstd::stream::raw::Decoder<'static>), + #[cfg(feature = "bzip")] + Bzip2(bzip2::Decompress), + #[cfg(feature = "xz")] + Xz(liblzma::stream::Stream), +} + +impl From for Decoder { + fn from(value: Codec) -> Self { + match value { + Codec::Null => Self::Null, + Codec::Deflate(_) => { + use miniz_oxide::{DataFormat::Raw, inflate::stream::InflateState}; + Self::Deflate(InflateState::new_boxed(Raw)) + } + #[cfg(feature = "snappy")] + Codec::Snappy => Self::Snappy(snap::raw::Decoder::new()), + 
#[cfg(feature = "zstandard")] + Codec::Zstandard(_) => Self::Zstandard(zstd::stream::raw::Decoder::new().unwrap()), + #[cfg(feature = "bzip")] + Codec::Bzip2(_) => Self::Bzip2(bzip2::Decompress::new(false)), + #[cfg(feature = "xz")] + Codec::Xz(_) => { + Self::Xz(liblzma::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap()) + } + } + } +} + +impl Decoder { + pub fn reset(&mut self) { + match self { + Decoder::Null => {} + Decoder::Deflate(decoder) => { + decoder.reset_as(miniz_oxide::inflate::stream::MinReset); + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => {} // No reset needed + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => zstd::stream::raw::Operation::reinit(decoder).unwrap(), + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace(decoder, bzip2::Decompress::new(false)); + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace( + decoder, + liblzma::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap(), + ); + } + } + } +} + +impl Fsm for CodecStateMachine { + type Output = (T::Output, Self); + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + let buffer = match &mut self.codec { + Decoder::Null => buffer, + Decoder::Deflate(decoder) => { + use miniz_oxide::{MZFlush, StreamResult, inflate::stream::inflate}; + let StreamResult { + bytes_consumed, + bytes_written, + status, + } = inflate(decoder, buffer.data(), self.buffer.space(), MZFlush::None); + status.unwrap(); + buffer.consume(bytes_consumed); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => { + todo!("Snap has no streaming decoder") + } + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => { + use zstd::stream::raw::{Operation, Status}; + let Status { + bytes_read, + bytes_written, + .. 
+ } = decoder + .run_on_buffers(buffer.data(), self.buffer.space()) + .map_err(crate::error::Details::ZstdDecompress)?; + buffer.consume(bytes_read); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .decompress(buffer.data(), self.buffer.space()) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + use liblzma::stream::Action::Run; + + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .process(buffer.data(), self.buffer.space(), Run) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + }; + match self + .sub_machine + .take() + .expect("CodecStateMachine was not reset!") + .parse(buffer)? 
+ { + FsmControlFlow::NeedMore(fsm) => { + self.sub_machine = Some(fsm); + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(result) => { + Ok(FsmControlFlow::Done((result, self))) + } + } + } +} diff --git a/avro/src/decode2/complex/block.rs b/avro/src/decode2/complex/block.rs new file mode 100644 index 00000000..e0725bf6 --- /dev/null +++ b/avro/src/decode2/complex/block.rs @@ -0,0 +1,258 @@ +use std::collections::HashMap; +use oval::Buffer; +use crate::decode2::{Fsm, FsmControlFlow, FsmResult, SubFsm}; +use crate::decode2::primitive::bytes::StringFsm; +use crate::decode2::decode_zigzag_buffer; +use crate::error::Details; +use crate::Schema; +use crate::types::Value; + +pub struct ArrayFsm<'a> { + schema: &'a Schema, + sub_fsm: Option>>, + left_in_current_block: usize, + need_to_read_block_byte_size: bool, + values: Vec, +} +impl<'a> ArrayFsm<'a> { + pub fn new(schema: &'a Schema) -> Self { + Self { + schema, + sub_fsm: None, + left_in_current_block: 0, + need_to_read_block_byte_size: false, + values: Vec::new(), + } + } +} +impl<'a> Fsm for ArrayFsm<'a> { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + // We loop until we're finished or we need more data + loop { + // If we finished the last block (or are newly created) read the block info + if self.left_in_current_block == 0 { + let Some(block) = decode_zigzag_buffer(buffer)? 
else { + return Ok(FsmControlFlow::NeedMore(self)); + }; + + // Done parsing the blocks + if block == 0 { + return Ok(FsmControlFlow::Done(Value::Array(self.values))); + } + + // Need to read the block byte size when block is negative + self.need_to_read_block_byte_size = block.is_negative(); + + // We do the rest with the absolute block size + let abs_block = usize::try_from(block.unsigned_abs()) + .map_err(|e| Details::ConvertU64ToUsize(e, block.unsigned_abs()))?; + + self.left_in_current_block = abs_block; + self.values.reserve(abs_block); + } + + // If the block length was negative we need to read the block size + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(FsmControlFlow::NeedMore(self)); + }; + + // Make sure the value is sane + // TODO: Maybe use safe_len here? + let _ = usize::try_from(block) + .map_err(|e| Details::ConvertI64ToUsize(e, block))?; + + // This is not necessary, as it will be overwritten before being read again + // but it does show the intent more clearly + self.need_to_read_block_byte_size = false; + } + + // Check if we already have a state machine to run + if let Some(sub_fsm) = self.sub_fsm.as_deref_mut() { + let fsm = std::mem::take(sub_fsm); + match fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + let _ = std::mem::replace(sub_fsm, fsm); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + self.left_in_current_block -= 1; + self.values.push(value); + // Assume we need to read another value + // We do this to reuse the Box we already have and therefore preventing a lot + // of allocations + let _ = std::mem::replace(sub_fsm, SubFsm::from(self.schema)); + // Continue the loop + continue; + } + } + } else { + let fsm = SubFsm::from(self.schema); + match fsm.parse(buffer)? 
{ + FsmControlFlow::NeedMore(fsm) => { + // Save the current progress and ask for more bytes + self.sub_fsm = Some(Box::new(fsm)); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + self.left_in_current_block -= 1; + self.values.push(value); + // As we don't have a box yet, we don't create the next state machine as that + // allocation might be unnecessary + continue; + } + } + } + } + } +} + +pub struct MapFsm<'a> { + schema: &'a Schema, + sub_fsm: Option>>, + left_in_current_block: usize, + need_to_read_block_byte_size: bool, + current_key: Option, + values: HashMap, +} +impl<'a> MapFsm<'a> { + pub fn new(schema: &'a Schema) -> Self { + Self { + schema, + sub_fsm: None, + left_in_current_block: 0, + need_to_read_block_byte_size: false, + current_key: None, + values: HashMap::new(), + } + } +} +impl<'a> Fsm for MapFsm<'a> { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + // We loop until we're finished or we need more data + loop { + // If we finished the last block (or are newly created) read the block info + if self.left_in_current_block == 0 { + let Some(block) = decode_zigzag_buffer(buffer)? else { + return Ok(FsmControlFlow::NeedMore(self)); + }; + + // Done parsing the blocks + if block == 0 { + return Ok(FsmControlFlow::Done(Value::Map(self.values))); + } + + // Need to read the block byte size when block is negative + self.need_to_read_block_byte_size = block.is_negative(); + + // We do the rest with the absolute block size + let abs_block = usize::try_from(block.unsigned_abs()) + .map_err(|e| Details::ConvertU64ToUsize(e, block.unsigned_abs()))?; + + self.left_in_current_block = abs_block; + self.values.reserve(abs_block); + } + + // If the block length was negative we need to read the block size + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? 
else { + // Not enough data left in the buffer + return Ok(FsmControlFlow::NeedMore(self)); + }; + + // Make sure the value is sane + // TODO: Maybe use safe_len here? + let _ = usize::try_from(block) + .map_err(|e| Details::ConvertI64ToUsize(e, block))?; + + // This is not necessary, as it will be overwritten before being read again + // but it does show the intent more clearly + self.need_to_read_block_byte_size = false; + } + + if let Some(key) = self.current_key.take() { + // Check if we already have a state machine to run + if let Some(sub_fsm) = self.sub_fsm.as_deref_mut() { + let fsm = std::mem::take(sub_fsm); + match fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + let _ = std::mem::replace(sub_fsm, fsm); + self.current_key = Some(key); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + self.left_in_current_block -= 1; + self.values.insert(key, value); + // Assume we need to read another key + // We do this to reuse the Box we already have and therefore preventing a lot + // of allocations + let _ = std::mem::replace(sub_fsm, SubFsm::String(StringFsm::default())); + // Continue the loop + continue; + } + } + } else { + let fsm = SubFsm::from(self.schema); + match fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + // Save the current progress and ask for more bytes + self.sub_fsm = Some(Box::new(fsm)); + self.current_key = Some(key); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + self.left_in_current_block -= 1; + self.values.insert(key, value); + // As we don't have a box yet, we don't create the next state machine as that + // allocation might be unnecessary + continue; + } + } + } + } else { + // Check if we already have a state machine to run + if let Some(sub_fsm) = self.sub_fsm.as_deref_mut() { + let fsm = std::mem::take(sub_fsm); + match fsm.parse(buffer)? 
{ + FsmControlFlow::NeedMore(fsm) => { + let _ = std::mem::replace(sub_fsm, fsm); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + let Value::String(key) = value else { + unreachable!() + }; + self.current_key = Some(key); + // Now we need to read the value, might as well reuse the allocation + let _ = std::mem::replace(sub_fsm, SubFsm::from(self.schema)); + // Continue the loop + continue; + } + } + } else { + let fsm = SubFsm::String(StringFsm::default()); + match fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + // Save the current progress and ask for more bytes + self.sub_fsm = Some(Box::new(fsm)); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + let Value::String(key) = value else { + unreachable!() + }; + self.current_key = Some(key); + // We don't have an allocation so there is nothing to reuse + continue; + } + } + } + } + } + } +} \ No newline at end of file diff --git a/avro/src/decode2/complex/mod.rs b/avro/src/decode2/complex/mod.rs new file mode 100644 index 00000000..b404764d --- /dev/null +++ b/avro/src/decode2/complex/mod.rs @@ -0,0 +1,34 @@ +use oval::Buffer; +use crate::decode2::{decode_zigzag_buffer, Fsm, FsmControlFlow, FsmResult}; +use crate::error::Details; +use crate::schema::EnumSchema; +use crate::types::Value; + +pub mod union; +pub mod block; +pub mod record; + +pub struct EnumFsm<'a> { + schema: &'a EnumSchema, +} +impl<'a> EnumFsm<'a> { + pub fn new(schema: &'a EnumSchema) -> Self { + Self { schema } + } +} +impl<'a> Fsm for EnumFsm<'a> { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + let Some(n) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(FsmControlFlow::NeedMore(self)); + }; + let n = u32::try_from(n).map_err(|_| Details::GetEnumUnknownIndexValue)?; + // If we truncate the value with `as usize` instead of try_from we might get a valid index + // value. 
+ let n_as_usize = usize::try_from(n).map_err(|_| Details::GetEnumUnknownIndexValue)?; + let symbol = self.schema.symbols.get(n_as_usize).cloned().ok_or(Details::GetEnumUnknownIndexValue)?; + Ok(FsmControlFlow::Done(Value::Enum(n, symbol))) + } +} diff --git a/avro/src/decode2/complex/record.rs b/avro/src/decode2/complex/record.rs new file mode 100644 index 00000000..69f1d14b --- /dev/null +++ b/avro/src/decode2/complex/record.rs @@ -0,0 +1,85 @@ +use oval::Buffer; +use crate::decode2::{Fsm, FsmControlFlow, FsmResult, SubFsm}; +use crate::schema::RecordSchema; +use crate::types::Value; + +pub struct RecordFsm<'a> { + schema: &'a RecordSchema, + current_field: usize, + sub_fsm: Option>>, + fields: Vec<(String, Value)>, +} +impl<'a> RecordFsm<'a> { + pub fn new(schema: &'a RecordSchema) -> Self { + Self { + schema, + current_field: 0, + sub_fsm: None, + fields: Vec::new(), + } + } +} + +impl<'a> Fsm for RecordFsm<'a> { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + // All fields are there, this should only be possible for an empty record. Which should + // not exist + if self.current_field >= self.schema.fields.len() { + return Ok(FsmControlFlow::Done(Value::Record(self.fields))); + } + loop { + if let Some(sub_fsm) = self.sub_fsm.as_deref_mut() { + let fsm = std::mem::take(sub_fsm); + match fsm.parse(buffer)? 
{ + FsmControlFlow::NeedMore(fsm) => { + let _ = std::mem::replace(sub_fsm, fsm); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + // Finished reading a field, add the name and value to the list + let field_name = self.schema.fields[self.current_field].name.clone(); + self.fields.push((field_name, value)); + assert_eq!(self.current_field, self.fields.len() - 1); + + self.current_field += 1; + + // If there is a next field, prepare the state machine in the same box + if let Some(field) = self.schema.fields.get(self.current_field) { + let _ = std::mem::replace(sub_fsm, SubFsm::from(&field.schema)); + // Restart the loop + continue; + } else { + assert_eq!(self.fields.len(), self.schema.fields.len()); + return Ok(FsmControlFlow::Done(Value::Record(self.fields))); + } + } + } + } else { + let schema = &self.schema.fields[self.current_field].schema; + match SubFsm::from(schema).parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + self.sub_fsm = Some(Box::new(fsm)); + return Ok(FsmControlFlow::NeedMore(self)); + } + FsmControlFlow::Done(value) => { + // Finished reading a field, add the name and value to the list + let field_name = self.schema.fields[self.current_field].name.clone(); + self.fields.push((field_name, value)); + assert_eq!(self.current_field, self.fields.len() - 1); + + self.current_field += 1; + + // If there is no next field, return the record + if self.schema.fields.get(self.current_field).is_none() { + assert_eq!(self.fields.len(), self.schema.fields.len()); + return Ok(FsmControlFlow::Done(Value::Record(self.fields))); + } + } + } + + } + } + } +} \ No newline at end of file diff --git a/avro/src/decode2/complex/union.rs b/avro/src/decode2/complex/union.rs new file mode 100644 index 00000000..d44e864e --- /dev/null +++ b/avro/src/decode2/complex/union.rs @@ -0,0 +1,64 @@ +use oval::Buffer; +use crate::decode2::{Fsm, FsmControlFlow, FsmResult, SubFsm}; +use crate::decode2::decode_zigzag_buffer; +use 
crate::error::Details; +use crate::schema::UnionSchema; +use crate::types::Value; + +pub struct UnionFsm<'a> { + schema: &'a UnionSchema, + index: Option, + sub_fsm: Option>> +} + +impl<'a> UnionFsm<'a> { + pub fn new(schema: &'a UnionSchema) -> UnionFsm<'a> { + Self { + schema, + index: None, + sub_fsm: None + } + } +} + +impl<'a> Fsm for UnionFsm<'a> { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + if self.index.is_none() { + let Some(n) = decode_zigzag_buffer(buffer)? else { + return Ok(FsmControlFlow::NeedMore(self)); + }; + // TODO: proper error + self.index = Some(u32::try_from(n).map_err(|e| Details::EmptyUnion)?); + } + let index = self.index.unwrap_or_else(|| unreachable!()); + if let Some(sub_fsm) = self.sub_fsm.as_deref_mut() { + let fsm = std::mem::take(sub_fsm); + match fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + let _ = std::mem::replace(sub_fsm, fsm); + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(value) => { + Ok(FsmControlFlow::Done(Value::Union(index, Box::new(value)))) + } + } + } else { + // TODO: proper error, GetUnionVariants + let index_usize = usize::try_from(index).map_err(|e| Details::EmptyUnion)?; + let schema = self.schema.schemas.get(index_usize).ok_or(Details::EmptyUnion)?; + + match SubFsm::from(schema).parse(buffer)? 
{ + FsmControlFlow::NeedMore(fsm) => { + self.sub_fsm = Some(Box::new(fsm)); + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(value) => { + Ok(FsmControlFlow::Done(Value::Union(index, Box::new(value)))) + } + } + } + } +} + diff --git a/avro/src/decode2/logical/decimal.rs b/avro/src/decode2/logical/decimal.rs new file mode 100644 index 00000000..c038bdac --- /dev/null +++ b/avro/src/decode2/logical/decimal.rs @@ -0,0 +1,80 @@ +use oval::Buffer; +use crate::decode2::primitive::bytes::{BytesFsm, FixedFsm}; +use crate::decode2::{FsmControlFlow, Fsm, FsmResult}; +use crate::{Decimal, Schema}; +use crate::bigdecimal::deserialize_big_decimal; +use crate::schema::DecimalSchema; +use crate::types::Value; + +#[derive(Default)] +pub struct BigDecimalFsm(BytesFsm); + +impl Fsm for BigDecimalFsm { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + self.0.parse(buffer)?.map_fallible(|fsm| Ok(Self(fsm)), |bytes| { + let Value::Bytes(bytes) = bytes else { unreachable!() }; + Ok(Value::BigDecimal(deserialize_big_decimal(&bytes)?)) + }) + } +} + +enum BytesOrFixedFsm { + Bytes(BytesFsm), + Fixed(FixedFsm), +} +impl Fsm for BytesOrFixedFsm { + type Output = Vec; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + match self { + BytesOrFixedFsm::Bytes(fsm) => { + Ok(fsm.parse(buffer)?.map(BytesOrFixedFsm::Bytes, |v| { + let Value::Bytes(bytes) = v else { unreachable!() }; + bytes + })) + } + BytesOrFixedFsm::Fixed(fsm) => { + Ok(fsm.parse(buffer)?.map(BytesOrFixedFsm::Fixed, |v| { + let Value::Fixed(_, bytes) = v else { unreachable!() }; + bytes + })) + } + } + } +} + +pub struct DecimalFsm { + fsm: BytesOrFixedFsm, +} + +impl DecimalFsm { + pub fn new(schema: &DecimalSchema) -> Self { + let fsm = if let Schema::Fixed(fixed) = schema.inner.as_ref() { + BytesOrFixedFsm::Fixed(FixedFsm::new(fixed.size)) + } else if let Schema::Bytes = schema.inner.as_ref() { + BytesOrFixedFsm::Bytes(BytesFsm::default()) + } else { + panic!("Invalid 
DecimalSchema, inner schema is not Fixed or Bytes"); + }; + Self { + fsm, + } + } +} +impl Fsm for DecimalFsm { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + match self.fsm.parse(buffer)? { + FsmControlFlow::NeedMore(new_fsm) => { + self.fsm = new_fsm; + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(bytes) => { + Ok(FsmControlFlow::Done(Value::Decimal(Decimal::from(bytes)))) + } + } + } +} diff --git a/avro/src/decode2/logical/mod.rs b/avro/src/decode2/logical/mod.rs new file mode 100644 index 00000000..8e6816a5 --- /dev/null +++ b/avro/src/decode2/logical/mod.rs @@ -0,0 +1,3 @@ +pub mod time; +pub mod decimal; +pub mod uuid; \ No newline at end of file diff --git a/avro/src/decode2/logical/time.rs b/avro/src/decode2/logical/time.rs new file mode 100644 index 00000000..1f5a6db7 --- /dev/null +++ b/avro/src/decode2/logical/time.rs @@ -0,0 +1,116 @@ +use std::ops::Deref; +use oval::Buffer; +use crate::decode2::{Fsm, FsmResult}; +use crate::decode2::primitive::bytes::FixedFsm; +use crate::decode2::primitive::zigzag::ZigZagFSM; +use crate::Duration; +use crate::error::Details; +use crate::types::Value; + +pub struct DurationFsm(FixedFsm); +impl DurationFsm { + pub fn new() -> Self { + Self(FixedFsm::new(12)) + } +} +impl Default for DurationFsm { + fn default() -> Self { + Self::new() + } +} +impl Fsm for DurationFsm { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, |v| { + let Value::Fixed(_, bytes) = v else { unreachable!() }; + let array: [u8; 12] = bytes.deref().try_into().unwrap(); + Value::Duration(Duration::from(array)) + })) + } +} + +#[derive(Default)] +pub struct DateFsm(ZigZagFSM); +impl Fsm for DateFsm { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + self.0.parse(buffer)?.map_fallible(|fsm| Ok(Self(fsm)), |n| { + let n = i32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + Ok(Value::Date(n)) + }) + } +} 
+ +#[derive(Default)] +pub struct TimeMillisFsm(ZigZagFSM); +impl Fsm for TimeMillisFsm { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + self.0.parse(buffer)?.map_fallible(|fsm| Ok(Self(fsm)), |n| { + let n = i32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + Ok(Value::TimeMillis(n)) + }) + } +} + +#[derive(Default)] +pub struct TimeMicrosFsm(ZigZagFSM); +impl Fsm for TimeMicrosFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, Value::TimeMicros)) + } +} + +#[derive(Default)] +pub struct TimestampMillisFsm(ZigZagFSM); +impl Fsm for TimestampMillisFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, Value::TimestampMillis)) + } +} +#[derive(Default)] +pub struct TimestampMicrosFsm(ZigZagFSM); +impl Fsm for TimestampMicrosFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, Value::TimestampMicros)) + } +} + +#[derive(Default)] +pub struct TimestampNanosFsm(ZigZagFSM); +impl Fsm for TimestampNanosFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, Value::TimestampNanos)) + } +} +#[derive(Default)] +pub struct LocalTimestampMillisFsm(ZigZagFSM); +impl Fsm for LocalTimestampMillisFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, Value::LocalTimestampMillis)) + } +} +#[derive(Default)] +pub struct LocalTimestampMicrosFsm(ZigZagFSM); +impl Fsm for LocalTimestampMicrosFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.0.parse(buffer)?.map(Self, Value::LocalTimestampMicros)) + } +} +#[derive(Default)] +pub struct LocalTimestampNanosFsm(ZigZagFSM); +impl Fsm for LocalTimestampNanosFsm { + type Output = Value; + fn parse(self, buffer: &mut Buffer) -> FsmResult { + 
Ok(self.0.parse(buffer)?.map(Self, Value::LocalTimestampNanos))
+    }
+}
\ No newline at end of file
diff --git a/avro/src/decode2/logical/uuid.rs b/avro/src/decode2/logical/uuid.rs
new file mode 100644
index 00000000..4f3dce31
--- /dev/null
+++ b/avro/src/decode2/logical/uuid.rs
@@ -0,0 +1,47 @@
+use oval::Buffer;
+use uuid::Uuid;
+use crate::decode2::primitive::bytes::{FixedFsm, StringFsm};
+use crate::decode2::{Fsm, FsmResult};
+use crate::error::Details;
+use crate::schema::UuidSchema;
+use crate::types::Value;
+
+/// Decoder for the `uuid` logical type, stored as either a string or a 16-byte fixed.
+pub enum UuidFsm {
+    String(StringFsm),
+    FixedFSM(FixedFsm),
+}
+impl UuidFsm {
+    pub fn new(schema: &UuidSchema) -> Self {
+        match schema {
+            UuidSchema::String => {
+                Self::String(StringFsm::default())
+            }
+            UuidSchema::Fixed(fixed_schema) => {
+                assert_eq!(fixed_schema.size, 16, "Uuid(Fixed) must be 16 bytes");
+                Self::FixedFSM(FixedFsm::new(16))
+            }
+        }
+    }
+}
+impl Fsm for UuidFsm {
+    type Output = Value;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        match self {
+            UuidFsm::String(fsm) => {
+                fsm.parse(buffer)?.map_fallible(|fsm| Ok(Self::String(fsm)), |v| {
+                    let Value::String(string) = v else { unreachable!() };
+                    Ok(Value::Uuid(Uuid::parse_str(&string).map_err(Details::UuidFromSlice)?))
+                })
+            }
+            UuidFsm::FixedFSM(fsm) => {
+                fsm.parse(buffer)?.map_fallible(|fsm| Ok(Self::FixedFSM(fsm)), |v| {
+                    // FixedFsm yields Value::Fixed (not Value::Bytes); matching Bytes here hit unreachable!().
+                    let Value::Fixed(_, bytes) = v else { unreachable!() };
+                    Ok(Value::Uuid(Uuid::from_slice(&bytes).map_err(Details::UuidFromSlice)?))
+                })
+            }
+        }
+    }
+}
diff --git a/avro/src/decode2/mod.rs b/avro/src/decode2/mod.rs
new file mode 100644
index 00000000..2c91caeb
--- /dev/null
+++ b/avro/src/decode2/mod.rs
@@ -0,0 +1,284 @@
+//! Low-level decoders for binary Avro data.
+//!
+//! # Low-level decoders
+//!
+//! This module contains the low-level decoders for binary Avro data. It is strongly recommended to
+//! use the high-level readers in [`reader`]. This should only be used if you are decoding a custom
+//! 
file format or are using an async runtime not supported in [`reader`]. +//! +//! # Usage +//! The decoder implementations are based on finite state machines, expressed through the [`Fsm`] +//! trait. A decoder is supplied with a [`Buffer`] that has some (but not necessarily all) data. The +//! decoder will decode as far as it can. If it does not have enough data, it will return +//! [`FsmControlFlow::NeedMore`]. By filling the buffer with more data, the decoder can be resumed. +//! If the end of the data stream is reached and the decoder is still returning [`FsmControlFlow::NeedMore`] +//! then the data stream is corrupt and an error should be thrown. +//! +//! If the decoder is finished decoding, it will return a [`FsmControlFlow::Done`] which will also +//! consume the decoder, preventing accidental reuse of an already finished decoder. +//! +//! The supplied buffer must have space for at least 16 bytes, any less and some decoders won't be +//! able to make progress. +//! +//! ```no_run +//! # use apache_avro::{Schema, decode2::{DatumFsm, Fsm, FsmControlFlow}}; +//! # use oval::Buffer; +//! # use std::{fs::File, io::Read}; +//! # let schema = Schema::Bytes; +//! +//! let mut buffer = Buffer::with_capacity(256); +//! let mut file = File::open("some").unwrap(); +//! let mut fsm = DatumFsm::new(&schema); +//! +//! // If the schema is Schema::Null, it's perfectly valid for the file to be empty +//! // so we don't check the amount of bytes filled the first time. +//! let bytes_read = file.read(buffer.space()).unwrap(); +//! buffer.fill(bytes_read); +//! +//! let result = loop { +//! match fsm.parse(&mut buffer).unwrap() { +//! FsmControlFlow::NeedMore(new_fsm) => { +//! fsm = new_fsm; +//! let bytes_read = file.read(buffer.space()).unwrap(); +//! if bytes_read == 0 { +//! panic!("File is finished but decoder is not"); +//! } +//! buffer.fill(bytes_read); +//! } +//! FsmControlFlow::Done(value) => { +//! break value; +//! } +//! } +//! }; +//! 
println!("{result:?}"); +//! ``` +//! +//! +//! [`reader`]: crate::reader + +/// Decoder for compressed Avro data. +mod codec; +/// Decoder for the Object Container Files. +pub mod object_container; +/// Decoders for primitive types (and `fixed`). +mod primitive; +/// Decoders for complex types (excluding `fixed`). +mod complex; +/// Decoders for logical types. +mod logical; + +use oval::Buffer; +use complex::union::UnionFsm; +use crate::{Error, Schema}; +use complex::block::{ArrayFsm, MapFsm}; +use primitive::bytes::{BytesFsm, FixedFsm, StringFsm}; +use logical::decimal::{BigDecimalFsm, DecimalFsm}; +use primitive::floats::{DoubleFsm, FloatFsm}; +use complex::record::RecordFsm; +use logical::time::{DateFsm, DurationFsm, LocalTimestampMicrosFsm, LocalTimestampMillisFsm, LocalTimestampNanosFsm, TimeMicrosFsm, TimeMillisFsm, TimestampMicrosFsm, TimestampMillisFsm, TimestampNanosFsm}; +use logical::uuid::UuidFsm; +use primitive::zigzag::{IntFsm, LongFsm}; +use crate::decode2::complex::EnumFsm; +use crate::decode2::primitive::{BoolFsm, NullFsm}; +use crate::schema::{ArraySchema, FixedSchema, MapSchema}; +use crate::types::Value; +use crate::util::decode_variable; + +/// Read a zigzagged varint from the buffer. +/// +/// Will only consume the buffer if a whole number has been read. +/// If insufficient bytes are available it will return `Ok(None)` to +/// indicate it needs more bytes. +fn decode_zigzag_buffer(buffer: &mut Buffer) -> Result, Error> { + if let Some((decoded, consumed)) = decode_variable(buffer.data())? { + buffer.consume(consumed); + Ok(Some(decoded)) + } else { + Ok(None) + } +} + +/// A trait for the lifecycle of a finite state machine. +pub trait Fsm: Sized { + /// The final output of the state machine. + type Output: Sized; + + /// Start/continue the state machine. + /// + /// Implementers are not allowed to return until they can't make progress anymore. 
+ fn parse(self, buffer: &mut Buffer) -> FsmResult; +} + +/// Indicates whether the state machine has completed or needs to be polled again. +#[must_use] +pub enum FsmControlFlow { + /// The state machine needs more data before it can continue. + NeedMore(Fsm), + /// The state machine is done and the result is returned. + Done(Output), +} + +impl FsmControlFlow { + /// Map a state machine to another state machine. + /// + /// This function will only execute `need_more` or `done`, not both. + pub fn map(self, need_more: F1, done: F2) -> FsmControlFlow + where F1: FnOnce(FSM1) -> FSM2, + F2: FnOnce(O1) -> O2, { + match self { + FsmControlFlow::NeedMore(fsm) => { + FsmControlFlow::NeedMore(need_more(fsm)) + } + FsmControlFlow::Done(fsm) => { + FsmControlFlow::Done(done(fsm)) + } + } + } + + pub fn map_fallible(self, need_more: F1, done: F2) -> Result, Error> + where F1: FnOnce(FSM1) -> Result, + F2: FnOnce(O1) -> Result, { + match self { + FsmControlFlow::NeedMore(fsm) => { + Ok(FsmControlFlow::NeedMore(need_more(fsm)?)) + } + FsmControlFlow::Done(fsm) => { + Ok(FsmControlFlow::Done(done(fsm)?)) + } + } + } +} + +pub type FsmResult = Result, Error>; + +pub struct DatumFsm<'a> { + /// We wrap around inner to hide implementation details from the user. 
+ fsm: SubFsm<'a> +} +impl<'a> DatumFsm<'a> { + pub fn new(schema: &'a Schema) -> Self { + Self { + fsm: SubFsm::from(schema) + } + } + + pub fn new_with_schemata(schema: &'a Schema, schemata: Vec<&'a Schema>) -> Self { + todo!() + } +} +impl<'a> Fsm for DatumFsm<'a> { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.fsm.parse(buffer)?.map(|fsm| Self { fsm }, |v| v)) + } +} + +enum SubFsm<'a> { + Null(NullFsm), + Boolean(BoolFsm), + Int(IntFsm), + Long(LongFsm), + Float(FloatFsm), + Double(DoubleFsm), + Bytes(BytesFsm), + String(StringFsm), + Fixed(FixedFsm), + Enum(EnumFsm<'a>), + Union(UnionFsm<'a>), + Array(ArrayFsm<'a>), + Map(MapFsm<'a>), + Record(RecordFsm<'a>), + Date(DateFsm), + Decimal(DecimalFsm), + BigDecimal(BigDecimalFsm), + TimeMillis(TimeMillisFsm), + TimeMicros(TimeMicrosFsm), + TimestampMillis(TimestampMillisFsm), + TimestampMicros(TimestampMicrosFsm), + TimestampNanos(TimestampNanosFsm), + LocalTimestampMillis(LocalTimestampMillisFsm), + LocalTimestampMicros(LocalTimestampMicrosFsm), + LocalTimestampNanos(LocalTimestampNanosFsm), + Duration(DurationFsm), + Uuid(UuidFsm), +} + +impl<'a> Default for SubFsm<'a> { + fn default() -> Self { + Self::Null(NullFsm) + } +} + +impl<'a> Fsm for SubFsm<'a> { + type Output = Value; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + match self { + Self::Null(fsm) => Ok(fsm.parse(buffer)?.map(Self::Null, |v| v)), + Self::Boolean(fsm) => Ok(fsm.parse(buffer)?.map(Self::Boolean, |v| v)), + Self::Int(fsm) => Ok(fsm.parse(buffer)?.map(Self::Int, |v| v)), + Self::Long(fsm) => Ok(fsm.parse(buffer)?.map(Self::Long, |v| v)), + Self::Float(fsm) => Ok(fsm.parse(buffer)?.map(Self::Float, |v| v)), + Self::Double(fsm) => Ok(fsm.parse(buffer)?.map(Self::Double, |v| v)), + Self::Bytes(fsm) => Ok(fsm.parse(buffer)?.map(Self::Bytes, |v| v)), + Self::String(fsm) => Ok(fsm.parse(buffer)?.map(Self::String, |v| v)), + Self::Fixed(fsm) => Ok(fsm.parse(buffer)?.map(Self::Fixed, |v| v)), + 
Self::Enum(fsm) => Ok(fsm.parse(buffer)?.map(Self::Enum, |v| v)), + Self::Union(fsm) => Ok(fsm.parse(buffer)?.map(Self::Union, |v| v)), + Self::Array(fsm) => Ok(fsm.parse(buffer)?.map(Self::Array, |v| v)), + Self::Map(fsm) => Ok(fsm.parse(buffer)?.map(Self::Map, |v| v)), + Self::Record(fsm) => Ok(fsm.parse(buffer)?.map(Self::Record, |v| v)), + Self::Date(fsm) => Ok(fsm.parse(buffer)?.map(Self::Date, |v| v)), + Self::Decimal(fsm) => Ok(fsm.parse(buffer)?.map(Self::Decimal, |v| v)), + Self::BigDecimal(fsm) => Ok(fsm.parse(buffer)?.map(Self::BigDecimal, |v| v)), + Self::TimeMillis(fsm) => Ok(fsm.parse(buffer)?.map(Self::TimeMillis, |v| v)), + Self::TimeMicros(fsm) => Ok(fsm.parse(buffer)?.map(Self::TimeMicros, |v| v)), + Self::TimestampMillis(fsm) => Ok(fsm.parse(buffer)?.map(Self::TimestampMillis, |v| v)), + Self::TimestampMicros(fsm) => Ok(fsm.parse(buffer)?.map(Self::TimestampMicros, |v| v)), + Self::TimestampNanos(fsm) => Ok(fsm.parse(buffer)?.map(Self::TimestampNanos, |v| v)), + Self::LocalTimestampMillis(fsm) => Ok(fsm.parse(buffer)?.map(Self::LocalTimestampMillis, |v| v)), + Self::LocalTimestampMicros(fsm) => Ok(fsm.parse(buffer)?.map(Self::LocalTimestampMicros, |v| v)), + Self::LocalTimestampNanos(fsm) => Ok(fsm.parse(buffer)?.map(Self::LocalTimestampNanos, |v| v)), + Self::Duration(fsm) => Ok(fsm.parse(buffer)?.map(Self::Duration, |v| v)), + Self::Uuid(fsm) => Ok(fsm.parse(buffer)?.map(Self::Uuid, |v| v)), + } + } +} + +impl<'a> From<&'a Schema> for SubFsm<'a> { + fn from(value: &'a Schema) -> Self { + match value { + Schema::Null => Self::Null(NullFsm), + Schema::Boolean => Self::Boolean(BoolFsm), + Schema::Int => Self::Int(IntFsm::default()), + Schema::Long => Self::Long(LongFsm::default()), + Schema::Float => Self::Float(FloatFsm), + Schema::Double => Self::Double(DoubleFsm), + Schema::Bytes => Self::Bytes(BytesFsm::default()), + Schema::String => Self::String(StringFsm::default()), + Schema::Array(ArraySchema { items, .. 
}) => Self::Array(ArrayFsm::new(items)), + Schema::Map(MapSchema { types, .. }) => Self::Map(MapFsm::new(types)), + Schema::Union(schema) => Self::Union(UnionFsm::new(schema)), + Schema::Record(schema) => Self::Record(RecordFsm::new(schema)), + Schema::Enum(schema) => Self::Enum(EnumFsm::new(schema)), + Schema::Fixed(FixedSchema { size, .. }) => Self::Fixed(FixedFsm::new(*size)), + Schema::Decimal(schema) => Self::Decimal(DecimalFsm::new(schema)), + Schema::BigDecimal => Self::BigDecimal(BigDecimalFsm::default()), + Schema::Uuid(schema) => Self::Uuid(UuidFsm::new(schema)), + Schema::Date => Self::Date(DateFsm::default()), + Schema::TimeMillis => Self::TimeMillis(TimeMillisFsm::default()), + Schema::TimeMicros => Self::TimeMicros(TimeMicrosFsm::default()), + Schema::TimestampMillis => Self::TimestampMillis(TimestampMillisFsm::default()), + Schema::TimestampMicros => Self::TimestampMicros(TimestampMicrosFsm::default()), + Schema::TimestampNanos => Self::TimestampNanos(TimestampNanosFsm::default()), + Schema::LocalTimestampMillis => Self::LocalTimestampMillis(LocalTimestampMillisFsm::default()), + Schema::LocalTimestampMicros => Self::LocalTimestampMicros(LocalTimestampMicrosFsm::default()), + Schema::LocalTimestampNanos => Self::LocalTimestampNanos(LocalTimestampNanosFsm::default()), + Schema::Duration => Self::Duration(DurationFsm::default()), + Schema::Ref { .. 
} => todo!() + } + } +} + + diff --git a/avro/src/decode2/object_container/data.rs b/avro/src/decode2/object_container/data.rs new file mode 100644 index 00000000..bf1a156e --- /dev/null +++ b/avro/src/decode2/object_container/data.rs @@ -0,0 +1,97 @@ +use std::io::Read; +use oval::Buffer; +use crate::decode2::object_container::Header; +use crate::decode2::{Fsm, FsmControlFlow, FsmResult, SubFsm}; +use crate::decode2::codec::CodecStateMachine; +use crate::decode2::{decode_zigzag_buffer}; +use crate::error::Details; +use crate::Schema; +use crate::types::Value; + +pub struct DataBlockFsm<'a> { + schema: &'a Schema, + fsm: CodecStateMachine>, + sync: [u8; 16], + left_in_block: usize, + need_to_read_block_byte_size: bool, + need_to_read_sync: bool, +} +impl<'a> DataBlockFsm<'a> { + pub fn new(header: &'a Header) -> Self { + let fsm = CodecStateMachine::new(SubFsm::from(&header.schema), header.codec); + Self { + schema: &header.schema, + fsm, + sync: header.sync, + left_in_block: 0, + need_to_read_block_byte_size: false, + need_to_read_sync: false, + } + } +} +impl<'a> Fsm for DataBlockFsm<'a> { + type Output = Option<(Value, Self)>; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + // If we have just finished a block or have just been created we need to read the block metadata + if self.left_in_block == 0 { + // At the end of a block we need to read the sync marker + if self.need_to_read_sync { + if buffer.available_data() < 16 { + return Ok(FsmControlFlow::NeedMore(self)); + } + + let mut sync = [0; 16]; + buffer.read_exact(&mut sync).unwrap_or_else(|_| unreachable!()); + + if sync != self.sync { + return Err(Details::GetBlockMarker.into()); + } + self.need_to_read_sync = false; + } + + // Read the amount of items in the block + let Some(block) = decode_zigzag_buffer(buffer)? 
else { + return Ok(FsmControlFlow::NeedMore(self)); + }; + + let abs_block = block.unsigned_abs(); + let abs_block = + usize::try_from(abs_block).map_err(|e| Details::ConvertU64ToUsize(e, abs_block))?; + if abs_block == 0 { + // Finished reading the file + return Ok(FsmControlFlow::Done(None)); + } + self.need_to_read_block_byte_size = true; + // This will only be done after left_in_block hits 0 + self.need_to_read_sync = true; + self.left_in_block = abs_block; + } + + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(FsmControlFlow::NeedMore(self)); + }; + // We can't use the size, but we should check that it is valid + let _size = usize::try_from(block).map_err(|e| Details::ConvertI64ToUsize(e, block))?; + self.need_to_read_block_byte_size = false; + } + + match self.fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + self.fsm = fsm; + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done((value, fsm)) => { + self.left_in_block -= 1; + + // Codec's inner FSM is finished so needs to be reset + self.fsm = fsm; + self.fsm.reset(SubFsm::from(self.schema)); + + Ok(FsmControlFlow::Done(Some((value, self)))) + } + } + } +} \ No newline at end of file diff --git a/avro/src/decode2/object_container/header.rs b/avro/src/decode2/object_container/header.rs new file mode 100644 index 00000000..c836cd39 --- /dev/null +++ b/avro/src/decode2/object_container/header.rs @@ -0,0 +1,200 @@ + +use std::collections::HashMap; +use std::io::Read; +use std::str::FromStr; +use log::warn; +use oval::Buffer; +use crate::{Codec, Error, Schema}; +use crate::decode2::complex::block::MapFsm; +use crate::decode2::{Fsm, FsmControlFlow, FsmResult}; +use crate::decode2::object_container::Header; +use crate::error::Details; +use crate::schema::{resolve_names, resolve_names_with_schemata, Names, ResolvedSchema}; +use crate::types::Value; + +/// Decode the header of an Object Container 
File. +/// +/// The output of this state machine can be used with [`DataBlockFsm`] to parse the body. +/// +/// [`DataBlockFsm`]: super::DataBlockFsm +pub struct HeaderFsm<'a> { + /// We wrap around inner to hide implementation details from the user. + fsm: InnerHeaderFsm<'a>, +} +impl<'a> HeaderFsm<'a> { + /// Create a new decoder. + pub fn new() -> Self { + Self { + fsm: InnerHeaderFsm::new(), + } + } + + /// Create a new decoder with schemata. + pub fn new_with_schemata(schemata: Vec<&'a Schema>) -> Self { + Self { + fsm: InnerHeaderFsm::new_with_schemata(schemata) + } + } +} +impl Default for HeaderFsm<'_> { + fn default() -> Self { + Self::new() + } +} +impl<'a> Fsm for HeaderFsm<'a> { + type Output = Header; + + fn parse(self, buffer: &mut Buffer) -> FsmResult { + Ok(self.fsm.parse(buffer)?.map(|fsm| Self { fsm }, |h| h)) + } +} + +/// The actual decoder for the header. +enum InnerHeaderFsm<'a> { + /// We start with reading the magic number and verifying this is an Object Container File. + ReadMagic { + schemata: Vec<&'a Schema> + }, + /// Then we read in all the metadata. + Metadata { + fsm: MapFsm<'static>, + schemata: Vec<&'a Schema>, + }, + /// Finally, we need to read the sync marker. 
+ Sync { + metadata: HashMap, + schemata: Vec<&'a Schema>, + }, +} +impl<'a> InnerHeaderFsm<'a> { + pub fn new() -> Self { + Self::ReadMagic { schemata: Vec::new() } + } + + pub fn new_with_schemata(schemata: Vec<&'a Schema>) -> Self { + Self::ReadMagic { schemata } + } +} + +impl<'a> Fsm for InnerHeaderFsm<'a> { + type Output = Header; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + loop { + match self { + InnerHeaderFsm::ReadMagic { schemata } => { + if buffer.available_data() < 4 { + return Ok(FsmControlFlow::NeedMore(Self::ReadMagic { schemata })); + } + if buffer.data()[0..4] != [b'O', b'b', b'j', 1] { + return Err(Details::HeaderMagic.into()); + } + buffer.consume(4); + self = Self::Metadata { fsm: MapFsm::new(&Schema::Bytes), schemata }; + } + InnerHeaderFsm::Metadata { fsm, schemata } => { + match fsm.parse(buffer)? { + FsmControlFlow::NeedMore(fsm) => { + return Ok(FsmControlFlow::NeedMore(Self::Metadata { fsm, schemata })); + } + FsmControlFlow::Done(value) => { + let Value::Map(metadata) = value else { unreachable!() }; + self = Self::Sync { metadata, schemata}; + } + } + } + InnerHeaderFsm::Sync { metadata, schemata } => { + if buffer.available_data() < 16 { + return Ok(FsmControlFlow::NeedMore(Self::Sync { metadata, schemata })); + } + let mut sync = [0; 16]; + buffer.read_exact(&mut sync).unwrap_or_else(|_| unreachable!()); + return Ok(FsmControlFlow::Done(create_header(metadata, sync, schemata)?)) + } + } + + } + } +} + +fn create_header(map: HashMap, sync: [u8; 16], mut schemata: Vec<&Schema>) -> Result { + let mut schema = None; + let mut codec = None; + let mut found_compression_level = false; + let mut metadata = HashMap::new(); + let mut names = HashMap::new(); + + for (key, value) in map { + let Value::Bytes(value) = value else { unreachable!() }; + match key.as_ref() { + "avro.schema" => { + if schema.is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + let json: serde_json::Value = + 
serde_json::from_slice(&value).map_err(Details::ParseSchemaJson)?; + + if !schemata.is_empty() { + // TODO: Make parse_with_names accept NamesRef + let schemata = std::mem::take(&mut schemata); + resolve_names_with_schemata(&schemata, &mut names, &None)?; + + // TODO: Maybe we can not do this, and just past &names to Schema::parse_with_names + let rs = ResolvedSchema::try_from(schemata)?; + let names: Names = rs + .get_names() + .iter() + .map(|(name, &schema)| (name.clone(), schema.clone())) + .collect(); + + let parsed_schema = Schema::parse_with_names(&json, names)?; + schema.replace(parsed_schema); + } else { + let parsed_schema = Schema::parse(&json)?; + resolve_names(&parsed_schema, &mut names, &None)?; + schema.replace(parsed_schema); + } + } + "avro.codec" => { + let string = String::from_utf8(value).map_err(Details::ConvertToUtf8)?; + let parsed_codec = Codec::from_str(&string) + .map_err(|_| Details::CodecNotSupported(string))?; + if codec.replace(parsed_codec).is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + } + "avro.codec.compression_level" => { + // Compression level is not useful for decoding + if found_compression_level { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + found_compression_level = true; + } + _ => { + if key.starts_with("avro.") { + warn!("Ignoring unknown metadata key: {key}"); + } + if metadata.insert(key, value).is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + } + } + } + + + let Some(schema) = schema else { + return Err(Details::GetHeaderMetadata.into()); + }; + let codec = codec.unwrap_or(Codec::Null); + Ok(Header { + schema, + names, + codec, + sync, + metadata, + }) +} diff --git a/avro/src/decode2/object_container/mod.rs b/avro/src/decode2/object_container/mod.rs new file mode 100644 index 00000000..22064d44 --- /dev/null +++ b/avro/src/decode2/object_container/mod.rs @@ -0,0 +1,24 @@ +use std::collections::HashMap; +use 
crate::{Codec, Schema}; +use crate::schema::Names; + +/// Decoder for the Object Container File header. +mod header; +/// Decoder for the Object Container File data blocks. +mod data; + +/// The header as read from an Object Container File. +pub struct Header { + /// The schema used to write the file. + pub schema: Schema, + pub names: Names, + /// The compression used. + pub codec: Codec, + /// The sync marker used between blocks + pub sync: [u8; 16], + /// User metadata in the header + pub metadata: HashMap>, +} + +pub use header::HeaderFsm; +pub use data::DataBlockFsm; diff --git a/avro/src/decode2/primitive/bytes.rs b/avro/src/decode2/primitive/bytes.rs new file mode 100644 index 00000000..9ebeb784 --- /dev/null +++ b/avro/src/decode2/primitive/bytes.rs @@ -0,0 +1,161 @@ +use oval::Buffer; +use crate::decode2::{Fsm, FsmControlFlow, FsmResult}; +use crate::decode2::decode_zigzag_buffer; +use crate::Error; +use crate::error::Details; +use crate::types::Value; + +#[derive(Default)] +pub struct BytesFsm { + inner: InnerBytesFSM, +} + +impl Fsm for BytesFsm { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + match self.inner.parse(buffer)? { + FsmControlFlow::NeedMore(new_innner) => { + self.inner = new_innner; + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(bytes) => { + Ok(FsmControlFlow::Done(Value::Bytes(bytes))) + } + } + } +} + +pub struct FixedFsm { + inner: InnerBytesFSM, +} + +impl FixedFsm { + pub fn new(length: usize) -> Self { + Self { + inner: InnerBytesFSM::new_with_length(length), + } + } +} + +impl Fsm for FixedFsm { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + match self.inner.parse(buffer)? 
{ + FsmControlFlow::NeedMore(new_innner) => { + self.inner = new_innner; + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(bytes) => { + Ok(FsmControlFlow::Done(Value::Fixed(bytes.len(), bytes))) + } + } + } +} + +#[derive(Default)] +pub struct StringFsm { + inner: InnerBytesFSM, + validated_up_to: usize, +} + +impl StringFsm { + /// Validate that the partially read string is valid UTF-8. + /// + /// This will only validate the part of the string that has not been validated yet. + /// + /// If `incomplete` is `true` this will allow the last 3 bytes of the string to be invalid. + /// When the full string is read, `incomplete` should be set to `false`, which will also validate + /// the last 3 bytes. + fn partial_validate(data: &[u8], mut validated_up_to: usize, incomplete: bool) -> Result { + let unvalidated = &data[validated_up_to..]; + match std::str::from_utf8(unvalidated) { + Ok(_) => { + validated_up_to = data.len(); + Ok(validated_up_to) + }, + Err(error) => { + validated_up_to += error.valid_up_to(); + if incomplete && validated_up_to + 3 >= data.len() { + Ok(validated_up_to) + } else { + Err(Details::ConvertToUtf8Error(error).into()) + } + } + } + } +} + +impl Fsm for StringFsm { + type Output = Value; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + match self.inner.parse(buffer)? 
{ + FsmControlFlow::NeedMore(new_innner) => { + self.inner = new_innner; + // Validate the part that was just read + self.validated_up_to = Self::partial_validate(&self.inner.data, self.validated_up_to, true)?; + Ok(FsmControlFlow::NeedMore(self)) + } + FsmControlFlow::Done(bytes) => { + // Validate the last bit + Self::partial_validate(&bytes, self.validated_up_to, false)?; + // SAFETY: bytes is valid, as it has been incrementally checked during read + let string = unsafe { + String::from_utf8_unchecked(bytes) + }; + Ok(FsmControlFlow::Done(Value::String(string))) + } + } + } +} + +#[derive(Default)] +struct InnerBytesFSM { + length: Option, + data: Vec, +} + +impl InnerBytesFSM { + pub fn new_with_length(length: usize) -> Self { + Self { + length: Some(length), + data: Vec::with_capacity(length), + } + } +} + +impl Fsm for InnerBytesFSM { + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> FsmResult { + if self.length.is_none() { + let Some(length) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer varint byte plus we know + // there at least 127 bytes in the buffer now (as otherwise we wouldn't need one more varint byte). + return Ok(FsmControlFlow::NeedMore(self)); + }; + let length = + usize::try_from(length).map_err(|e| Details::ConvertI64ToUsize(e, length))?; + self.length = Some(length); + self.data.reserve_exact(length); + } + // This was just set in the previous if statement and it returns if that was not possible to do. 
+        let Some(length) = self.length else {
+            unreachable!()
+        };
+
+        // How much more data is needed
+        let remaining = length - self.data.len();
+        // How much of that is available in the buffer
+        let available = remaining.min(buffer.available_data());
+        self.data.extend_from_slice(&buffer.data()[..available]);
+        buffer.consume(available);
+        if remaining - available == 0 {
+            Ok(FsmControlFlow::Done(self.data))
+        } else {
+            Ok(FsmControlFlow::NeedMore(self))
+        }
+    }
+}
\ No newline at end of file
diff --git a/avro/src/decode2/primitive/floats.rs b/avro/src/decode2/primitive/floats.rs
new file mode 100644
index 00000000..ec9545e6
--- /dev/null
+++ b/avro/src/decode2/primitive/floats.rs
@@ -0,0 +1,36 @@
+use std::io::Read;
+use oval::Buffer;
+use crate::decode2::{Fsm, FsmControlFlow, FsmResult};
+use crate::types::Value;
+
+pub struct FloatFsm;
+impl Fsm for FloatFsm {
+    type Output = Value;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        if buffer.available_data() < 4 {
+            Ok(FsmControlFlow::NeedMore(self))
+        } else {
+            let mut bytes = [0; 4];
+            buffer.read_exact(&mut bytes).unwrap_or_else(|_| unreachable!());
+            let float = f32::from_le_bytes(bytes);
+            Ok(FsmControlFlow::Done(Value::Float(float)))
+        }
+    }
+}
+
+pub struct DoubleFsm;
+impl Fsm for DoubleFsm {
+    type Output = Value;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        if buffer.available_data() < 8 {
+            Ok(FsmControlFlow::NeedMore(self))
+        } else {
+            let mut bytes = [0; 8];
+            buffer.read_exact(&mut bytes).unwrap_or_else(|_| unreachable!());
+            let double = f64::from_le_bytes(bytes);
+            Ok(FsmControlFlow::Done(Value::Double(double)))
+        }
+    }
+}
\ No newline at end of file
diff --git a/avro/src/decode2/primitive/mod.rs b/avro/src/decode2/primitive/mod.rs
new file mode 100644
index 00000000..6a468620
--- /dev/null
+++ b/avro/src/decode2/primitive/mod.rs
@@ -0,0 +1,40 @@
+use std::io::Read;
+use oval::Buffer;
+use crate::decode2::{Fsm, FsmControlFlow, FsmResult};
+use crate::error::Details;
+use crate::types::Value;
+
+pub mod zigzag;
+pub mod floats;
+pub mod bytes;
+
+pub struct NullFsm;
+impl Fsm for NullFsm {
+    type Output = Value;
+
+    fn parse(self, _buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        Ok(FsmControlFlow::Done(Value::Null))
+    }
+}
+
+pub struct BoolFsm;
+impl Fsm for BoolFsm {
+    type Output = Value;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        // Per the Fsm contract an empty buffer means "need more data";
+        // without this guard read_exact would panic via the expect below.
+        if buffer.available_data() < 1 {
+            return Ok(FsmControlFlow::NeedMore(self));
+        }
+        let mut byte = [0; 1];
+        buffer
+            .read_exact(&mut byte)
+            .expect("Unreachable! Buffer is not empty");
+        match byte {
+            [0] => Ok(FsmControlFlow::Done(Value::Boolean(false))),
+            [1] => Ok(FsmControlFlow::Done(Value::Boolean(true))),
+            [byte] => Err(Details::BoolValue(byte).into()),
+        }
+    }
+}
diff --git a/avro/src/decode2/primitive/zigzag.rs b/avro/src/decode2/primitive/zigzag.rs
new file mode 100644
index 00000000..cbb700ff
--- /dev/null
+++ b/avro/src/decode2/primitive/zigzag.rs
@@ -0,0 +1,45 @@
+use oval::Buffer;
+use crate::decode2::{Fsm, FsmControlFlow, FsmResult};
+use crate::error::Details;
+use crate::types::Value;
+use crate::util::decode_variable;
+
+#[derive(Default)]
+pub struct IntFsm(ZigZagFSM);
+
+impl Fsm for IntFsm {
+    type Output = Value;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        self.0.parse(buffer)?.map_fallible(|fsm| Ok(Self(fsm)), |n| {
+            let n = i32::try_from(n).map_err(|e| Details::ZagI32(e, n))?;
+            Ok(Value::Int(n))
+        })
+    }
+}
+
+#[derive(Default)]
+pub struct LongFsm(ZigZagFSM);
+
+impl Fsm for LongFsm {
+    type Output = Value;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        Ok(self.0.parse(buffer)?.map(Self, Value::Long))
+    }
+}
+
+#[derive(Default)]
+pub struct ZigZagFSM;
+impl Fsm for ZigZagFSM {
+    type Output = i64;
+
+    fn parse(self, buffer: &mut Buffer) -> FsmResult<Self, Self::Output> {
+        if let Some((decoded, consumed)) = decode_variable(buffer.data())? 
{ + buffer.consume(consumed); + Ok(FsmControlFlow::Done(decoded)) + } else { + Ok(FsmControlFlow::NeedMore(self)) + } + } +} \ No newline at end of file diff --git a/avro/src/encode.rs b/avro/src/encode.rs index 38203f9d..7ec65fd3 100644 --- a/avro/src/encode.rs +++ b/avro/src/encode.rs @@ -28,6 +28,7 @@ use crate::{ }; use log::error; use std::{borrow::Borrow, collections::HashMap, io::Write}; +use crate::schema::UuidSchema; /// Encode a `Value` into avro format. /// @@ -39,6 +40,9 @@ pub fn encode(value: &Value, schema: &Schema, writer: &mut W) -> AvroR encode_internal(value, schema, rs.get_names(), &None, writer) } +/// Encode `s` as the _bytes_ primitive type. +/// +/// This writes the length as the _long_ primitive and then the bytes. pub(crate) fn encode_bytes + ?Sized, W: Write>( s: &B, mut writer: W, @@ -133,13 +137,13 @@ pub(crate) fn encode_internal>( .map_err(|e| Details::WriteBytes(e).into()) } Value::Uuid(uuid) => match *schema { - Schema::Uuid | Schema::String => encode_bytes( + Schema::Uuid(UuidSchema::String) | Schema::String => encode_bytes( // we need the call .to_string() to properly convert ASCII to UTF-8 #[allow(clippy::unnecessary_to_owned)] &uuid.to_string(), writer, ), - Schema::Fixed(FixedSchema { size, .. }) => { + Schema::Uuid(UuidSchema::Fixed(FixedSchema { size, .. })) | Schema::Fixed(FixedSchema { size, .. }) => { if size != 16 { return Err(Details::ConvertFixedToUuid(size).into()); } @@ -171,7 +175,7 @@ pub(crate) fn encode_internal>( .into()), }, Value::String(s) => match *schema { - Schema::String | Schema::Uuid => encode_bytes(s, writer), + Schema::String | Schema::Uuid(UuidSchema::String) => encode_bytes(s, writer), Schema::Enum(EnumSchema { ref symbols, .. 
}) => { if let Some(index) = symbols.iter().position(|item| item == s) { encode_int(index as i32, writer) @@ -934,7 +938,7 @@ pub(crate) mod tests { #[test] fn test_avro_3585_encode_uuids() { let value = Value::String(String::from("00000000-0000-0000-0000-000000000000")); - let schema = Schema::Uuid; + let schema = Schema::Uuid(UuidSchema::String); let mut buffer = Vec::new(); let encoded = encode(&value, &schema, &mut buffer); assert!(encoded.is_ok()); diff --git a/avro/src/encode2/block.rs b/avro/src/encode2/block.rs new file mode 100644 index 00000000..6c135f61 --- /dev/null +++ b/avro/src/encode2/block.rs @@ -0,0 +1,110 @@ +use oval::Buffer; + +use crate::{ + Error, + encode2::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, datum::DatumStateMachine, + decode_zigzag_buffer, + }, + error::Details, +}; + +/// Are we currently parsing an object or just finished/reading a block header +enum TapeOrFsm { + Tape(Vec), + Fsm(DatumStateMachine), +} + +pub struct BlockStateMachine { + command_tape: CommandTape, + tape_or_fsm: TapeOrFsm, + left_in_current_block: usize, + need_to_read_block_byte_size: bool, +} + +impl BlockStateMachine { + pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { + Self { + // This clone is *cheap* + command_tape, + tape_or_fsm: TapeOrFsm::Tape(tape), + left_in_current_block: 0, + need_to_read_block_byte_size: false, + } + } +} + +impl StateMachine for BlockStateMachine { + type Output = Vec; + fn parse( + mut self, + buffer: &mut Buffer, + ) -> Result, Error> { + loop { + match self.tape_or_fsm { + TapeOrFsm::Tape(mut tape) => { + // If we finished the last block (or are newly created) read the block info + if self.left_in_current_block == 0 { + let Some(block) = decode_zigzag_buffer(buffer)? 
else { + // Not enough data left in the buffer + self.tape_or_fsm = TapeOrFsm::Tape(tape); + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + + // Need to read the block byte size when block is negative + self.need_to_read_block_byte_size = block.is_negative(); + + // We do the rest with the absolute block size + let abs_block = usize::try_from(block.unsigned_abs()) + .map_err(|e| Details::ConvertU64ToUsize(e, block.unsigned_abs()))?; + self.left_in_current_block = abs_block; + tape.push(ItemRead::Block(abs_block)); + + // Done parsing the blocks + if abs_block == 0 { + return Ok(StateMachineControlFlow::Done(tape)); + } + } + + // If the block length was negative we need to read the block size + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + self.tape_or_fsm = TapeOrFsm::Tape(tape); + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + + // Make sure the value is sane + // TODO: Maybe use safe_len here? + let _ = usize::try_from(block) + .map_err(|e| Details::ConvertI64ToUsize(e, block))?; + + // This is not necessary, as it will be overwritten before being read again + // but it does show the intent more clearly + self.need_to_read_block_byte_size = false; + } + + // We've either finished reading the block header or the last object was read and + // left_in_current_block is not zero + self.tape_or_fsm = TapeOrFsm::Fsm(DatumStateMachine::new_with_tape( + self.command_tape.clone(), + tape, + )) + } + TapeOrFsm::Fsm(fsm) => { + // (Continue) reading the object + match fsm.parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(fsm); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(tape) => { + self.tape_or_fsm = TapeOrFsm::Tape(tape); + self.left_in_current_block -= 1; + } + } + } + } + } + } +} diff --git a/avro/src/encode2/bytes.rs b/avro/src/encode2/bytes.rs new file mode 100644 index 00000000..8a4fb587 --- /dev/null +++ b/avro/src/encode2/bytes.rs @@ -0,0 +1,71 @@ +use oval::Buffer; + +use crate::{ + encode2::{StateMachine, StateMachineControlFlow, decode_zigzag_buffer}, + error::Details, +}; + +use super::StateMachineResult; + +// TODO: Also make a String specific state machine. This allows checking the utf-8 while parsing +// which would make the parser fail quicker on large invalid strings. +// TODO: This state machine could also produce inline strings (smolstr) for strings smaller than +// size_of::, and use some extra bits to store well-known strings +// like avro.schema and avro.codec as fixed strings. + +#[derive(Default)] +pub struct BytesStateMachine { + length: Option, + data: Vec, +} + +impl BytesStateMachine { + pub fn new() -> Self { + Self { + length: None, + data: Vec::new(), + } + } + + pub fn new_with_length(length: usize) -> Self { + Self { + length: Some(length), + data: Vec::with_capacity(length), + } + } +} + +impl StateMachine for BytesStateMachine { + // This is a Vec instead of a Box<[u8]> as it's easier to create a string from a vec + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + if self.length.is_none() { + let Some(length) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer varint byte plus we know + // there at least 127 bytes in the buffer now (as otherwise we wouldn't need one more varint byte). 
+ return Ok(StateMachineControlFlow::NeedMore(self)); + }; + let length = + usize::try_from(length).map_err(|e| Details::ConvertI64ToUsize(e, length))?; + self.length = Some(length); + self.data.reserve_exact(length); + } + // This was just set in the previous if statement and it returns if that was not possible to do. + let Some(length) = self.length else { + unreachable!() + }; + + // How much more data is needed + let remaining = length - self.data.len(); + // How much of that is available in the buffer + let available = remaining.min(buffer.available_data()); + self.data.extend_from_slice(&buffer.data()[..available]); + buffer.consume(available); + if remaining - available == 0 { + Ok(StateMachineControlFlow::Done(self.data)) + } else { + Ok(StateMachineControlFlow::NeedMore(self)) + } + } +} diff --git a/avro/src/encode2/codec.rs b/avro/src/encode2/codec.rs new file mode 100644 index 00000000..01dfb30d --- /dev/null +++ b/avro/src/encode2/codec.rs @@ -0,0 +1,182 @@ +use crate::{ + Codec, + encode2::{StateMachine, StateMachineControlFlow, StateMachineResult}, +}; +use oval::Buffer; + +pub struct CodecStateMachine { + sub_machine: Option, + codec: Decoder, + buffer: Buffer, +} + +impl CodecStateMachine { + pub fn new(sub_machine: T, codec: Codec) -> Self { + Self { + sub_machine: Some(sub_machine), + codec: codec.into(), + buffer: Buffer::with_capacity(1024), + } + } + + pub fn reset(&mut self, sub_machine: T) { + self.buffer.reset(); + self.sub_machine = Some(sub_machine); + self.codec.reset(); + } +} + +pub enum Decoder { + Null, + Deflate(Box), + #[cfg(feature = "snappy")] + Snappy(snap::raw::Decoder), + #[cfg(feature = "zstandard")] + Zstandard(zstd::stream::raw::Decoder<'static>), + #[cfg(feature = "bzip")] + Bzip2(bzip2::Decompress), + #[cfg(feature = "xz")] + Xz(liblzma::stream::Stream), +} + +impl From for Decoder { + fn from(value: Codec) -> Self { + match value { + Codec::Null => Self::Null, + Codec::Deflate(_) => { + use miniz_oxide::{DataFormat::Raw, 
inflate::stream::InflateState}; + Self::Deflate(InflateState::new_boxed(Raw)) + } + #[cfg(feature = "snappy")] + Codec::Snappy => Self::Snappy(snap::raw::Decoder::new()), + #[cfg(feature = "zstandard")] + Codec::Zstandard(_) => Self::Zstandard(zstd::stream::raw::Decoder::new().unwrap()), + #[cfg(feature = "bzip")] + Codec::Bzip2(_) => Self::Bzip2(bzip2::Decompress::new(false)), + #[cfg(feature = "xz")] + Codec::Xz(_) => { + Self::Xz(liblzma::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap()) + } + } + } +} + +impl Decoder { + pub fn reset(&mut self) { + match self { + Decoder::Null => {} + Decoder::Deflate(decoder) => { + decoder.reset_as(miniz_oxide::inflate::stream::MinReset); + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => {} // No reset needed + #[cfg(feature = "zstandard")] + Decoder::Zstandard(decoder) => zstd::stream::raw::Operation::reinit(decoder).unwrap(), + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace(decoder, bzip2::Decompress::new(false)); + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + // No reset/reinit API available + let _drop = std::mem::replace( + decoder, + liblzma::stream::Stream::new_auto_decoder(u64::MAX, 0).unwrap(), + ); + } + } + } +} + +impl StateMachine for CodecStateMachine { + type Output = (T::Output, Self); + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + let buffer = match &mut self.codec { + Decoder::Null => buffer, + Decoder::Deflate(decoder) => { + use miniz_oxide::{MZFlush, StreamResult, inflate::stream::inflate}; + let StreamResult { + bytes_consumed, + bytes_written, + status, + } = inflate(decoder, buffer.data(), self.buffer.space(), MZFlush::None); + status.unwrap(); + buffer.consume(bytes_consumed); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "snappy")] + Decoder::Snappy(_decoder) => { + todo!("Snap has no streaming decoder") + } + #[cfg(feature = "zstandard")] 
+ Decoder::Zstandard(decoder) => { + use zstd::stream::raw::{Operation, Status}; + let Status { + bytes_read, + bytes_written, + .. + } = decoder + .run_on_buffers(buffer.data(), self.buffer.space()) + .map_err(crate::error::Details::ZstdDecompress)?; + buffer.consume(bytes_read); + self.buffer.fill(bytes_written); + + &mut self.buffer + } + #[cfg(feature = "bzip")] + Decoder::Bzip2(decoder) => { + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .decompress(buffer.data(), self.buffer.space()) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + #[cfg(feature = "xz")] + Decoder::Xz(decoder) => { + use liblzma::stream::Action::Run; + + let prev_total_in = decoder.total_in(); + let prev_total_out = decoder.total_out(); + + let _status = decoder + .process(buffer.data(), self.buffer.space(), Run) + .unwrap(); + + let consumed = decoder.total_in() - prev_total_in; + let filled = decoder.total_out() - prev_total_out; + + buffer.consume(usize::try_from(consumed).unwrap()); + self.buffer.fill(usize::try_from(filled).unwrap()); + + &mut self.buffer + } + }; + match self + .sub_machine + .take() + .expect("CodecStateMachine was not reset!") + .parse(buffer)? 
+ { + StateMachineControlFlow::NeedMore(fsm) => { + self.sub_machine = Some(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(result) => { + Ok(StateMachineControlFlow::Done((result, self))) + } + } + } +} diff --git a/avro/src/encode2/commands.rs b/avro/src/encode2/commands.rs new file mode 100644 index 00000000..5d97217a --- /dev/null +++ b/avro/src/encode2/commands.rs @@ -0,0 +1,646 @@ +use crate::{ + Error, Schema, + encode2::{ + ItemRead, SubStateMachine, block::BlockStateMachine, bytes::BytesStateMachine, + datum::DatumStateMachine, union::UnionStateMachine, + }, + error::Details, + schema::{ + ArraySchema, DecimalSchema, EnumSchema, FixedSchema, MapSchema, Name, Names, RecordSchema, + UnionSchema, + }, +}; +use std::{collections::HashMap, ops::Range, sync::Arc}; + +/// The next item type that should be read. +#[must_use] +pub enum ToRead { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, + Enum, + Ref(CommandTape), + Fixed(usize), + Block(CommandTape), + Union { + variants: CommandTape, + num_variants: usize, + }, +} + +impl ToRead { + pub fn into_state_machine(self, read: Vec) -> SubStateMachine { + match self { + ToRead::Null => SubStateMachine::Null(read), + ToRead::Boolean => SubStateMachine::Bool(read), + ToRead::Int => SubStateMachine::Int(read), + ToRead::Long => SubStateMachine::Long(read), + ToRead::Float => SubStateMachine::Float(read), + ToRead::Double => SubStateMachine::Double(read), + ToRead::Enum => SubStateMachine::Enum(read), + ToRead::Bytes => SubStateMachine::Bytes { + fsm: BytesStateMachine::new(), + read, + }, + ToRead::String => SubStateMachine::String { + fsm: BytesStateMachine::new(), + read, + }, + ToRead::Fixed(length) => SubStateMachine::Bytes { + fsm: BytesStateMachine::new_with_length(length), + read, + }, + ToRead::Ref(commands) => { + SubStateMachine::Object(DatumStateMachine::new_with_tape(commands, read)) + } + ToRead::Block(commands) => { + 
SubStateMachine::Block(BlockStateMachine::new_with_tape(commands, read)) + } + ToRead::Union { + variants, + num_variants, + } => SubStateMachine::Union(UnionStateMachine::new_with_tape( + variants, + num_variants, + read, + )), + } + } +} + +impl std::fmt::Debug for ToRead { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Null => write!(f, "Null"), + Self::Boolean => write!(f, "Boolean"), + Self::Int => write!(f, "Int"), + Self::Long => write!(f, "Long"), + Self::Float => write!(f, "Float"), + Self::Double => write!(f, "Double"), + Self::Bytes => write!(f, "Bytes"), + Self::String => write!(f, "String"), + Self::Enum => write!(f, "Enum"), + // We don't show the Ref command as that could recurse forever + Self::Ref(_) => write!(f, "Ref<...>"), + Self::Fixed(arg0) => write!(f, "Fixed<{arg0}>"), + Self::Block(arg0) => f.debug_tuple("Block").field(arg0).finish(), + Self::Union { variants, .. } => f.debug_tuple("Union").field(variants).finish(), + } + } +} + +/// A section of a tape of commands. +/// +/// This has a reference to the entire tape, so that references to types (for Union,Map,Array) can be resolved. +#[derive(Clone, PartialEq)] +#[must_use] +pub struct CommandTape { + inner: Arc<[u8]>, + read_range: Range, +} + +impl CommandTape { + pub const NULL: u8 = 0; + pub const BOOLEAN: u8 = 1; + pub const INT: u8 = 2; + pub const LONG: u8 = 3; + pub const FLOAT: u8 = 4; + pub const DOUBLE: u8 = 5; + pub const BYTES: u8 = 6; + pub const STRING: u8 = 7; + pub const ENUM: u8 = 8; + /// A fixed amount of bytes. + /// + /// If the amount of bytes is smaller than or equal to `0xF`, the amount is stored in the four + /// most significant bits of the byte. Otherwise, it's stored as a native endian usize directly + /// after the command byte. + pub const FIXED: u8 = 9; + /// A block based format follows (i.e. Map or Array). + /// + /// The command sequence of the type in the block follows immediately after the command byte. 
+ /// The length of the sequence is stored in the most significant four bits of the command byte. + /// If the sequence is larger than `0xF`, then either the entire sequence or part of it is + /// put behind a [`Self::REF`]. + pub const BLOCK: u8 = 10; + pub const UNION: u8 = 11; + /// A reference to a command sequence somewhere else in the tape. + /// + /// If the length of the sequence is smaller than or equal to `0xF`, the length is stored in the + /// four most significant bits of the byte. Otherwise, it's stored as a native endian usize + /// directly after the command byte. After the length follows the offset as a native endian + /// usize. + pub const REF: u8 = 12; + /// Skip the next `n` commands. + /// + /// A SKIP command is not counted as a command. + /// + /// If `n` is smaller than or equal to `0xF`, the amount is stored in the four most significant + /// bits of the byte. Otherwise, it's stored as a native endian usize directly after the command + /// byte. + pub const SKIP: u8 = 13; + + /// Create a new tape that will be read from start to end. + pub fn new(command_tape: Arc<[u8]>) -> Self { + let length = command_tape.len(); + Self { + inner: command_tape, + read_range: 0..length, + } + } + + pub fn build_from_schema(schema: &Schema, names: &Names) -> Result { + CommandTapeBuilder::build(schema, names) + } + + /// Check if the section of the tape we're reading is finished. + pub fn is_finished(&self) -> bool { + self.read_range.is_empty() + } + + /// Extract a part from the tape to give to a sub-state machine. + /// + /// The tape will run from offset for the given amount of commands. 
+ pub fn extract(&self, offset: usize, commands: usize) -> Self { + let mut temp = Self { + inner: self.inner.clone(), + read_range: offset..self.inner.len(), + }; + temp.skip(commands); + let max_index = temp.read_range.next().unwrap_or(self.inner.len()); + + assert!( + max_index <= self.inner.len(), + "Reference is (partly) outside the tape" + ); + Self { + inner: self.inner.clone(), + read_range: offset..max_index, + } + } + + /// Extract many parts from the tape to give to the Union state machine. + /// + /// The tapes will run from start to end (inclusive). + pub fn extract_many(&self, parts: &[(usize, usize)]) -> Box<[Self]> { + let mut vec = Vec::with_capacity(parts.len()); + for &(start, end) in parts { + vec.push(self.extract(start, end)); + } + vec.into_boxed_slice() + } + + /// Read an array of bytes from the tape. + fn read_array(&mut self) -> [u8; N] { + let start = self.read_range.next().expect("Read past the limit"); + let end = self.read_range.nth(N - 2).expect("Read past the limit"); + self.inner[start..=end].try_into().expect("Unreachable!") + } + + fn read_inline_or(&mut self, byte: u8) -> usize { + if byte >> 4 != 0 { + // Length is stored inline + (byte >> 4) as usize + } else { + usize::from_ne_bytes(self.read_array()) + } + } + + /// Get the next command from the tape. + /// + /// Will return `None` if exhausted. 
+ pub fn command(&mut self) -> Option { + if let Some(position) = self.read_range.next() { + let byte = self.inner[position]; + match byte & 0xF { + Self::NULL => Some(ToRead::Null), + Self::BOOLEAN => Some(ToRead::Boolean), + Self::INT => Some(ToRead::Int), + Self::LONG => Some(ToRead::Long), + Self::FLOAT => Some(ToRead::Float), + Self::DOUBLE => Some(ToRead::Double), + Self::BYTES => Some(ToRead::Bytes), + Self::STRING => Some(ToRead::String), + Self::ENUM => Some(ToRead::Enum), + Self::FIXED => Some(ToRead::Fixed(self.read_inline_or(byte))), + Self::BLOCK => { + // ToRead::Block + let size = (byte >> 4) as usize; + self.skip(size); + Some(ToRead::Block(self.extract(position + 1, size))) + } + Self::UNION => { + // How many variants are there? + let num_variants = self.read_inline_or(byte); + + // Skip over the union variants while keeping track of their start and end + // so we can easily create the command tape + let start = self.read_range.start; + self.skip(num_variants); + let end = self.read_range.start; + + // Create the command tape from the previously tracked start and end + let mut tape = self.clone(); + tape.read_range.start = start; + tape.read_range.end = end; + + Some(ToRead::Union { + variants: tape, + num_variants, + }) + } + Self::REF => { + let size = self.read_inline_or(byte); + let offset = usize::from_ne_bytes(self.read_array()); + Some(ToRead::Ref(self.extract(offset, size))) + } + Self::SKIP => { + // Read how many commands to skip and skip them + let commands = self.read_inline_or(byte); + self.skip(commands); + + // Return the next command + self.command() + } + _ => unreachable!(), // TODO: There is room here to specialize certain types, like a Union of Null and some other type + } + } else { + None + } + } + + /// Skip `amount` commands. + /// + /// If a command contains subcommands, these will also be skipped. 
+ /// + /// # Returns + /// `None` if it read past the end of the tape + pub(crate) fn skip(&mut self, mut amount: usize) -> Option<()> { + let mut i = 0; + while i < amount { + let position = self.read_range.next()?; + let byte = self.inner[position]; + match byte & 0xF { + CommandTape::BOOLEAN + | CommandTape::INT + | CommandTape::LONG + | CommandTape::FLOAT + | CommandTape::DOUBLE + | CommandTape::BYTES + | CommandTape::STRING + | CommandTape::ENUM + | CommandTape::NULL => {} + CommandTape::FIXED => { + let _size = self.read_inline_or(byte); + } + CommandTape::REF => { + let _size = self.read_inline_or(byte); + let _offset = usize::from_ne_bytes(self.read_array()); + } + CommandTape::UNION | CommandTape::BLOCK | CommandTape::SKIP => { + // These commands can inline other commands, so add them to the skip list + let num_variants = self.read_inline_or(byte); + amount += num_variants; + + // Skip does not count as a command, but we do increment `i` so we compensate + // for that by incrementing the amount + if byte & 0xF == CommandTape::SKIP { + amount += 1; + } + } + _ => unreachable!(), + } + i += 1; + } + Some(()) + } +} + +impl std::fmt::Debug for CommandTape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut c = self.clone(); + + write!(f, "CommandTape: ")?; + let mut list = f.debug_list(); + while let Some(command) = c.command() { + list.entry(&command); + } + list.finish() + } +} + +struct CommandTapeBuilder<'a> { + tape: Vec, + references: HashMap<&'a Name, (usize, usize)>, + names: &'a Names, +} + +impl<'a> CommandTapeBuilder<'a> { + pub fn new(names: &'a Names) -> Self { + Self { + tape: Vec::new(), + references: HashMap::new(), + names, + } + } + + fn add_schema(&mut self, schema: &'a Schema, inline_up_to: usize) -> Result { + match schema { + Schema::Null => { + self.tape.push(CommandTape::NULL); + Ok(1) + } + Schema::Boolean => { + self.tape.push(CommandTape::BOOLEAN); + Ok(1) + } + Schema::Int | Schema::Date | 
Schema::TimeMillis => { + self.tape.push(CommandTape::INT); + Ok(1) + } + Schema::Long + | Schema::TimeMicros + | Schema::TimestampMillis + | Schema::TimestampMicros + | Schema::TimestampNanos + | Schema::LocalTimestampMillis + | Schema::LocalTimestampMicros + | Schema::LocalTimestampNanos => { + self.tape.push(CommandTape::LONG); + Ok(1) + } + Schema::Float => { + self.tape.push(CommandTape::FLOAT); + Ok(1) + } + Schema::Double => { + self.tape.push(CommandTape::DOUBLE); + Ok(1) + } + Schema::Bytes | Schema::BigDecimal => { + self.tape.push(CommandTape::BYTES); + Ok(1) + } + Schema::String | Schema::Uuid => { + self.tape.push(CommandTape::STRING); + Ok(1) + } + Schema::Array(ArraySchema { items, .. }) => { + let block_offset = self.tape.len(); + self.tape.push(CommandTape::BLOCK); + let commands = self.add_schema(items, 16)?; + self.tape[block_offset] = CommandTape::BLOCK | (commands << 4) as u8; + Ok(1) + } + Schema::Map(MapSchema { types, .. }) => { + let block_offset = self.tape.len(); + self.tape.push(CommandTape::BLOCK); + self.tape.push(CommandTape::STRING); + let commands = self.add_schema(types, 15)?; + self.tape[block_offset] = CommandTape::BLOCK | ((commands + 1) << 4) as u8; + Ok(1) + } + Schema::Union(UnionSchema { schemas, .. }) => { + let schema_len = schemas.len(); + if 0 < schema_len && schema_len <= 0xF { + self.tape.push(CommandTape::UNION | (schema_len << 4) as u8); + } else { + self.tape.push(CommandTape::UNION); + self.tape.extend_from_slice(&schema_len.to_ne_bytes()); + } + for schema in schemas { + self.add_schema(schema, 1)?; + } + Ok(1) + } + Schema::Record(RecordSchema { name, fields, .. }) => { + if let Some(&(offset, commands)) = self.references.get(name) { + self.add_reference(offset, commands); + Ok(1) + } else if fields.is_empty() { + panic!("Record has no fields! 
{schema:?}"); + } else { + let commands = fields.len(); + if commands > inline_up_to { + // If this record is larger than the amount we're allowed to inline, inject + // a SKIP command. + if commands <= 0xF { + self.tape.push(CommandTape::SKIP | (commands << 4) as u8); + } else { + self.tape.push(CommandTape::SKIP); + self.tape.extend_from_slice(&commands.to_ne_bytes()); + } + } + let offset = self.tape.len(); + self.references.insert(name, (offset, commands)); + for field in fields { + let _commands = self.add_schema(&field.schema, 1)?; + } + if commands > inline_up_to { + // Now refer back to the skip block + self.add_reference(offset, commands); + Ok(1) + } else { + Ok(commands) + } + } + } + Schema::Enum(EnumSchema { name, .. }) => { + let offset = self.tape.len(); + let commands = 1; + self.tape.push(CommandTape::ENUM); + self.references.insert(name, (offset, commands)); + Ok(1) + } + Schema::Fixed(FixedSchema { name, size, .. }) => { + let offset = self.tape.len(); + if 0 < *size && *size <= 0xF { + self.tape.push(CommandTape::FIXED | (*size << 4) as u8); + } else { + self.tape.push(CommandTape::FIXED); + self.tape.extend_from_slice(&size.to_ne_bytes()); + } + self.references.entry(name).or_insert((offset, 1)); + Ok(1) + } + Schema::Decimal(DecimalSchema { inner, .. 
}) => self.add_schema(inner, inline_up_to), + Schema::Duration => { + self.tape.push(CommandTape::FIXED | 12 << 4); + Ok(1) + } + Schema::Ref { name } => { + if let Some(&(offset, commands)) = self.references.get(name) { + self.add_reference(offset, commands); + Ok(1) + } else if let Some(schema) = self.names.get(name).as_ref() { + self.add_schema(schema, inline_up_to) + } else { + Err(Details::SchemaResolutionError(name.clone()).into()) + } + } + } + } + + fn add_reference(&mut self, offset: usize, commands: usize) { + if commands == 0 { + self.tape.push(CommandTape::NULL); + } else if commands <= 0xF { + self.tape.push(CommandTape::REF | (commands << 4) as u8); + } else { + self.tape.push(CommandTape::REF); + self.tape.extend_from_slice(&commands.to_ne_bytes()); + } + self.tape.extend_from_slice(&offset.to_ne_bytes()); + } + + pub fn build(schema: &Schema, names: &'a Names) -> Result { + let mut builder = Self::new(names); + + builder.add_schema(schema, usize::MAX)?; + + let tape_len = builder.tape.len(); + Ok(CommandTape { + inner: Arc::from(builder.tape), + read_range: 0..tape_len, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn command_tape_simple() { + assert_eq!( + CommandTape::build_from_schema(&Schema::Null, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::NULL] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Boolean, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::BOOLEAN] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Int, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Date, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimeMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::INT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Long, &HashMap::new()) + 
.unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimeMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::TimestampNanos, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampMillis, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampMicros, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::LocalTimestampNanos, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::LONG] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Float, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::FLOAT] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Double, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::DOUBLE] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Bytes, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::BYTES] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::String, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::STRING] + ); + assert_eq!( + CommandTape::build_from_schema(&Schema::Uuid, &HashMap::new()) + .unwrap() + .inner + .as_ref(), + &[CommandTape::STRING] + ); + } +} diff --git a/avro/src/encode2/datum.rs b/avro/src/encode2/datum.rs new file mode 100644 index 00000000..67a7e384 --- /dev/null +++ 
b/avro/src/encode2/datum.rs @@ -0,0 +1,80 @@ +use oval::Buffer; + +use crate::encode2::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, + SubStateMachine, +}; + +enum TapeOrFsm { + Tape(Vec), + Fsm(Box), +} + +pub struct DatumStateMachine { + command_tape: CommandTape, + tape_or_fsm: TapeOrFsm, +} + +impl DatumStateMachine { + /// Create a new state machine that reads a datum from the commands. + pub fn new(command_tape: CommandTape) -> Self { + Self::new_with_tape(command_tape, Vec::new()) + } + + /// Create a new state machine that appends to the tape (the tape is returned on completion). + pub fn new_with_tape(command_tape: CommandTape, tape: Vec) -> Self { + Self { + command_tape, + tape_or_fsm: TapeOrFsm::Tape(tape), + } + } +} + +impl StateMachine for DatumStateMachine { + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + // While there's data and commands to process we keep progressing the state machines + while !buffer.data().is_empty() { + match self.tape_or_fsm { + TapeOrFsm::Fsm(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(Box::new(fsm)); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(read) => { + self.tape_or_fsm = TapeOrFsm::Tape(read); + } + }, + TapeOrFsm::Tape(tape) => { + if let Some(command) = self.command_tape.command() { + let fsm = command.into_state_machine(tape); + // This is a duplicate of the TapeOrFsm::Fsm logic, but saves us an allocation + // by doing it immediately. + match fsm.parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.tape_or_fsm = TapeOrFsm::Fsm(Box::new(fsm)); + return Ok(StateMachineControlFlow::NeedMore(self)); + } + StateMachineControlFlow::Done(read) => { + self.tape_or_fsm = TapeOrFsm::Tape(read); + } + } + } else { + self.tape_or_fsm = TapeOrFsm::Tape(tape); + break; + } + } + } + } + + // Check if we're completely finished or need more data + match (self.tape_or_fsm, self.command_tape.is_finished()) { + (TapeOrFsm::Tape(read), true) => Ok(StateMachineControlFlow::Done(read)), + (tape_or_fsm, _) => { + self.tape_or_fsm = tape_or_fsm; + Ok(StateMachineControlFlow::NeedMore(self)) + } + } + } +} diff --git a/avro/src/encode2/error.rs b/avro/src/encode2/error.rs new file mode 100644 index 00000000..b73c4346 --- /dev/null +++ b/avro/src/encode2/error.rs @@ -0,0 +1,16 @@ +use crate::{Schema, encode2::ItemRead}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum EncodeError { + #[error("Unexpected end of tape while building Value")] + UnexpectedEndOfTape, + #[error( + "Mismatch between tape and schema while building Value: schema {schema}, tape: {item:?}" + )] + TapeSchemaMismatch { schema: Schema, item: ItemRead }, + #[error( + "Mismatch between tape and schema while building Value: Schema::Fixed expected {expected} bytes, but tape had {actual}" + )] + TapeSchemaMismatchFixed { expected: usize, actual: usize }, +} diff --git a/avro/src/encode2/mod.rs b/avro/src/encode2/mod.rs index 8b137891..5a329968 100644 --- a/avro/src/encode2/mod.rs +++ b/avro/src/encode2/mod.rs @@ -1 +1,162 @@ +use crate::{ + Error, +}; +use oval::Buffer; + +pub trait StateMachine: Sized { + type Output: Sized; + + /// Start/continue the state machine. + /// + /// Implementers are not allowed to return until they can't make progress anymore. + fn parse(self, buffer: &mut Buffer) -> StateMachineResult; +} + +/// Indicates whether the state machine has completed or needs to be polled again. 
+#[must_use] +pub enum StateMachineControlFlow { + /// The state machine needs more data before it can continue. + NeedMore(StateMachine), + /// The state machine is done and the result is returned. + Done(Output), +} + +pub type StateMachineResult = + Result, Error>; + +/// The sub state machine that is currently being driven. +/// +/// The `Int`, `Long`, `Float`, `Double`, and `Enum` state machines don't have state, as +/// they don't consume the buffer if there are not enough bytes. This means that the only +/// thing these state machines are keeping track of is which type we're actually decoding. +pub enum SubStateMachine { + // Null(Vec), + // Bool(Vec), + // Int(Vec), + // Long(Vec), + // Float(Vec), + // Double(Vec), + // Enum(Vec), + // Bytes { + // fsm: BytesStateMachine, + // read: Vec, + // }, + // String { + // fsm: BytesStateMachine, + // read: Vec, + // }, + // Block(BlockStateMachine), + // Object(DatumStateMachine), + // Union(UnionStateMachine), +} + +impl StateMachine for SubStateMachine { + type Output = (); + + fn parse(self, _buffer: &mut Buffer) -> StateMachineResult { + match self { + // SubStateMachine::Null(mut read) => { + // read.push(ItemRead::Null); + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Bool(mut read) => { + // let mut byte = [0; 1]; + // buffer + // .read_exact(&mut byte) + // .expect("Unreachable! Buffer is not empty"); + // match byte { + // [0] => read.push(ItemRead::Boolean(false)), + // [1] => read.push(ItemRead::Boolean(true)), + // [byte] => return Err(Details::BoolValue(byte).into()), + // } + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Int(mut read) => { + // let Some(n) = decode_zigzag_buffer(buffer)? 
else { + // // Not enough data left in the buffer + // return Ok(StateMachineControlFlow::NeedMore(Self::Int(read))); + // }; + // let n = i32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + // read.push(ItemRead::Int(n)); + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Long(mut read) => { + // let Some(n) = decode_zigzag_buffer(buffer)? else { + // // Not enough data left in the buffer + // return Ok(StateMachineControlFlow::NeedMore(Self::Long(read))); + // }; + // read.push(ItemRead::Long(n)); + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Float(mut read) => { + // let Some(bytes) = buffer.data().first_chunk().copied() else { + // // Not enough data left in the buffer + // return Ok(StateMachineControlFlow::NeedMore(Self::Float(read))); + // }; + // buffer.consume(4); + // read.push(ItemRead::Float(f32::from_le_bytes(bytes))); + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Double(mut read) => { + // let Some(bytes) = buffer.data().first_chunk().copied() else { + // // Not enough data left in the buffer + // return Ok(StateMachineControlFlow::NeedMore(Self::Double(read))); + // }; + // buffer.consume(8); + // read.push(ItemRead::Double(f64::from_le_bytes(bytes))); + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Enum(mut read) => { + // let Some(n) = decode_zigzag_buffer(buffer)? else { + // // Not enough data left in the buffer + // return Ok(StateMachineControlFlow::NeedMore(Self::Enum(read))); + // }; + // // TODO: Wrong error + // let n = u32::try_from(n).map_err(|e| Details::ZagI32(e, n))?; + // read.push(ItemRead::Enum(n)); + // Ok(StateMachineControlFlow::Done(read)) + // } + // SubStateMachine::Bytes { fsm, mut read } => match fsm.parse(buffer)? 
{ + // StateMachineControlFlow::NeedMore(fsm) => { + // Ok(StateMachineControlFlow::NeedMore(Self::Bytes { fsm, read })) + // } + // StateMachineControlFlow::Done(bytes) => { + // read.push(ItemRead::Bytes(bytes)); + // Ok(StateMachineControlFlow::Done(read)) + // } + // }, + // SubStateMachine::String { fsm, mut read } => match fsm.parse(buffer)? { + // StateMachineControlFlow::NeedMore(fsm) => { + // Ok(StateMachineControlFlow::NeedMore(Self::String { + // fsm, + // read, + // })) + // } + // StateMachineControlFlow::Done(bytes) => { + // let string = String::from_utf8(bytes).map_err(Details::ConvertToUtf8)?; + // read.push(ItemRead::String(string)); + // Ok(StateMachineControlFlow::Done(read)) + // } + // }, + // SubStateMachine::Block(fsm) => match fsm.parse(buffer)? { + // StateMachineControlFlow::NeedMore(fsm) => { + // Ok(StateMachineControlFlow::NeedMore(Self::Block(fsm))) + // } + // StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + // }, + // SubStateMachine::Union(fsm) => match fsm.parse(buffer)? { + // StateMachineControlFlow::NeedMore(fsm) => { + // Ok(StateMachineControlFlow::NeedMore(Self::Union(fsm))) + // } + // StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + // }, + // SubStateMachine::Object(fsm) => match fsm.parse(buffer)? 
{ + // StateMachineControlFlow::NeedMore(fsm) => { + // Ok(StateMachineControlFlow::NeedMore(Self::Object(fsm))) + // } + // StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + // }, + } + } +} diff --git a/avro/src/encode2/object_container_file.rs b/avro/src/encode2/object_container_file.rs new file mode 100644 index 00000000..504e3258 --- /dev/null +++ b/avro/src/encode2/object_container_file.rs @@ -0,0 +1,318 @@ +use crate::{ + Codec, Error, Schema, + encode2::{ + CommandTape, ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, + codec::CodecStateMachine, datum::DatumStateMachine, decode_zigzag_buffer, + }, + error::Details, + schema::{Names, ResolvedSchema, resolve_names, resolve_names_with_schemata}, +}; +use log::warn; +use oval::Buffer; +use serde_json::Value; +use std::{collections::HashMap, io::Read, str::FromStr, sync::Arc}; + +// TODO: Dynamically/const construct this, this one works only on 64-bit LE +/// The tape corresponding to [`HEADER_JSON`]. +/// +/// ```json +/// { +/// "type": "record", +/// "name": "org.apache.avro.file.HeaderNoMagic", +/// "fields": [ +/// {"name": "meta", "type": {"type": "map", "values": "bytes"}}, +/// {"name": "sync", "type": {"type": "fixed", "name": "Sync", "size": 16}} +/// ] +/// } +/// ``` +#[rustfmt::skip] +const HEADER_TAPE: &[u8] = &[ + CommandTape::BLOCK | 2 << 4, // Starts with a map + CommandTape::STRING, // The keys are strings + CommandTape::BYTES, // The values are bytes + CommandTape::FIXED, // After the map there is a Fixed amount of bytes + 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // The amount of bytes is 0x0F +]; +#[cfg(test)] +const HEADER_JSON: &str = r#"{"type": "record","name": "org.apache.avro.file.HeaderNoMagic","fields": [{"name": "meta", "type": {"type": "map", "values": "bytes"}},{"name": "sync", "type": {"type": "fixed", "name": "Sync", "size": 16}}]}"#; + +/// The header as read from an Object Container file format. 
+pub struct ObjectContainerFileHeader { + /// The schema used to write the file. + pub schema: Schema, + pub names: Names, + /// The compression used. + pub codec: Codec, + /// The sync marker used between blocks + pub sync: [u8; 16], + /// User metadata in the header + pub metadata: HashMap>, +} + +impl ObjectContainerFileHeader { + pub fn command_tape() -> CommandTape { + CommandTape::new(Arc::from(HEADER_TAPE)) + } + + /// Create the header from an output tape. + /// + /// # Panics + /// Will panic if the tape was not produced from [`Self::command_tape()`]. + pub fn from_tape(mut tape: Vec, mut schemata: Vec<&Schema>) -> Result { + // Vec::remove(0) is an O(N) operation, so we use `drain` to read from front to back + let mut tape = tape.drain(..); + + let mut schema = None; + let mut codec = None; + let mut found_compression_level = false; + let mut metadata = HashMap::new(); + let mut names = HashMap::new(); + + while let Some(ItemRead::Block(items_left)) = tape.next() { + if items_left == 0 { + // Got to the end of the map + break; + } + for _ in 0..items_left { + let Some(ItemRead::String(key)) = tape.next() else { + panic!("The input does not correspond to the command tape"); + }; + let Some(ItemRead::Bytes(value)) = tape.next() else { + panic!("The input does not correspond to the command tape"); + }; + + match key.as_ref() { + "avro.schema" => { + if schema.is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + let json: Value = + serde_json::from_slice(&value).map_err(Details::ParseSchemaJson)?; + + if !schemata.is_empty() { + // TODO: Make parse_with_names accept NamesRef + let schemata = std::mem::take(&mut schemata); + resolve_names_with_schemata(&schemata, &mut names, &None)?; + + // TODO: Maybe we can not do this, and just past &names to Schema::parse_with_names + let rs = ResolvedSchema::try_from(schemata)?; + let names: Names = rs + .get_names() + .iter() + .map(|(name, &schema)| (name.clone(), schema.clone())) + 
.collect(); + + let parsed_schema = Schema::parse_with_names(&json, names)?; + schema.replace(parsed_schema); + } else { + let parsed_schema = Schema::parse(&json)?; + resolve_names(&parsed_schema, &mut names, &None)?; + schema.replace(parsed_schema); + } + } + "avro.codec" => { + let string = String::from_utf8(value).map_err(Details::ConvertToUtf8)?; + let parsed_codec = Codec::from_str(&string) + .map_err(|_| Details::CodecNotSupported(string))?; + if codec.replace(parsed_codec).is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + } + "avro.codec.compression_level" => { + // Compression level is not useful for decoding + if found_compression_level { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + found_compression_level = true; + } + _ => { + if key.starts_with("avro.") { + warn!("Ignoring unknown metadata key: {key}"); + } + if metadata.insert(key, value).is_some() { + // Duplicate key + return Err(Details::GetHeaderMetadata.into()); + } + } + } + } + } + let Some(schema) = schema else { + return Err(Details::GetHeaderMetadata.into()); + }; + let codec = codec.unwrap_or(Codec::Null); + let Some(ItemRead::Bytes(raw_sync)) = tape.next() else { + panic!("The input does not correspond to the command tape"); + }; + let sync = raw_sync + .as_slice() + .try_into() + .expect("The input does not correspond to the command tape"); + Ok(ObjectContainerFileHeader { + schema, + names, + codec, + sync, + metadata, + }) + } +} + +/// A state machine for parsing the header of the Object Container file format. +/// +/// After finishing this state machine the body can be read with [`ObjectContainerFileBodyStateMachine`]. +pub struct ObjectContainerFileHeaderStateMachine<'a> { + /// The actual state machine used to parse the header. + /// + /// This doesn't actually need to be an [`Option`] as it's constructed in [`Self::new`]. 
However, + /// as [`StateMachine::parse`] takes `self` we need it in an `Option` so we can do [`Option::take`]. + fsm: Option, + read_magic: bool, + schemata: Vec<&'a Schema>, +} + +impl<'a> ObjectContainerFileHeaderStateMachine<'a> { + pub fn new(schemata: Vec<&'a Schema>) -> Self { + let commands = CommandTape::new(Arc::from(HEADER_TAPE)); + Self { + fsm: Some(DatumStateMachine::new(commands)), + read_magic: false, + schemata, + } + } +} + +impl StateMachine for ObjectContainerFileHeaderStateMachine<'_> { + type Output = ObjectContainerFileHeader; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + while !self.read_magic { + if buffer.available_data() < 4 { + return Ok(StateMachineControlFlow::NeedMore(self)); + } + if buffer.data()[0..4] != [b'O', b'b', b'j', 1] { + return Err(Details::HeaderMagic.into()); + } + buffer.consume(4); + self.read_magic = true; + } + match self.fsm.take().expect("Unreachable!").parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + let _ = self.fsm.insert(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(tape) => Ok(StateMachineControlFlow::Done( + ObjectContainerFileHeader::from_tape(tape, self.schemata)?, + )), + } + } +} + +pub struct ObjectContainerFileBodyStateMachine { + fsm: Option>, + tape: CommandTape, + sync: [u8; 16], + left_in_block: usize, + need_to_read_block_byte_size: bool, + need_to_read_sync: bool, +} + +impl ObjectContainerFileBodyStateMachine { + pub fn new(tape: CommandTape, sync: [u8; 16], codec: Codec) -> Self { + Self { + fsm: Some(CodecStateMachine::new( + DatumStateMachine::new(tape.clone()), + codec, + )), + tape, + sync, + left_in_block: 0, + need_to_read_block_byte_size: false, + need_to_read_sync: false, + } + } +} + +impl StateMachine for ObjectContainerFileBodyStateMachine { + type Output = Option<(Vec, Self)>; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + if self.left_in_block == 0 { + if self.need_to_read_sync { 
+ if buffer.available_data() < 16 { + return Ok(StateMachineControlFlow::NeedMore(self)); + } + let mut sync = [0; 16]; + assert_eq!( + buffer.read(&mut sync).expect("Unreachable!"), + 16, + "Did not read enough data!" + ); + if sync != self.sync { + return Err(Details::GetBlockMarker.into()); + } + self.need_to_read_sync = false; + } + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + let abs_block = block.unsigned_abs(); + let abs_block = + usize::try_from(abs_block).map_err(|e| Details::ConvertU64ToUsize(e, abs_block))?; + if abs_block == 0 { + // Done parsing the array + return Ok(StateMachineControlFlow::Done(None)); + } + self.need_to_read_block_byte_size = true; + // This will only be done after this block is finished + self.need_to_read_sync = true; + self.left_in_block = abs_block; + } + if self.need_to_read_block_byte_size { + let Some(block) = decode_zigzag_buffer(buffer)? else { + // Not enough data left in the buffer + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + // Make sure the value is sane + let _size = usize::try_from(block).map_err(|e| Details::ConvertI64ToUsize(e, block))?; + self.need_to_read_block_byte_size = false; + } + + match self.fsm.take().expect("Unreachable!").parse(buffer)? 
{ + StateMachineControlFlow::NeedMore(fsm) => { + self.fsm.replace(fsm); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done((result, mut codec)) => { + codec.reset(DatumStateMachine::new(self.tape.clone())); + self.fsm.replace(codec); + self.left_in_block -= 1; + Ok(StateMachineControlFlow::Done(Some((result, self)))) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use crate::{ + Schema, + encode2::{ + commands::CommandTape, + object_container_file::{HEADER_JSON, HEADER_TAPE}, + }, + }; + + #[test] + pub fn header_tape() { + let schema = Schema::parse_str(HEADER_JSON).unwrap(); + let tape = CommandTape::build_from_schema(&schema, &HashMap::new()).unwrap(); + assert_eq!(tape, CommandTape::new(Arc::from(HEADER_TAPE))); + } +} diff --git a/avro/src/encode2/union.rs b/avro/src/encode2/union.rs new file mode 100644 index 00000000..cc3fa4da --- /dev/null +++ b/avro/src/encode2/union.rs @@ -0,0 +1,78 @@ +use crate::{ + encode2::{ + ItemRead, StateMachine, StateMachineControlFlow, StateMachineResult, SubStateMachine, + commands::CommandTape, decode_zigzag_buffer, + }, + error::Details, +}; +use oval::Buffer; + +enum VariantsOrFsm { + Variants { + variants: CommandTape, + read: Vec, + }, + Fsm(Box), +} + +pub struct UnionStateMachine { + variants_or_fsm: VariantsOrFsm, + num_variants: usize, +} + +impl UnionStateMachine { + pub fn new_with_tape(variants: CommandTape, num_variants: usize, read: Vec) -> Self { + Self { + variants_or_fsm: VariantsOrFsm::Variants { variants, read }, + num_variants, + } + } +} + +impl StateMachine for UnionStateMachine { + type Output = Vec; + + fn parse(mut self, buffer: &mut Buffer) -> StateMachineResult { + match self.variants_or_fsm { + VariantsOrFsm::Variants { + mut variants, + mut read, + } => { + let Some(index) = decode_zigzag_buffer(buffer)? 
else { + // Not enough data left in the buffer + self.variants_or_fsm = VariantsOrFsm::Variants { variants, read }; + return Ok(StateMachineControlFlow::NeedMore(self)); + }; + let option = + usize::try_from(index).map_err(|e| Details::ConvertI64ToUsize(e, index))?; + + variants.skip(option).ok_or(Details::GetUnionVariant { + index, + num_variants: self.num_variants, + })?; + + let variant = variants.command().ok_or(Details::GetUnionVariant { + index, + num_variants: self.num_variants, + })?; + + read.push(ItemRead::Union(u32::try_from(option).unwrap())); + + match variant.into_state_machine(read).parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + self.variants_or_fsm = VariantsOrFsm::Fsm(Box::new(fsm)); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + } + } + VariantsOrFsm::Fsm(fsm) => match fsm.parse(buffer)? { + StateMachineControlFlow::NeedMore(fsm) => { + self.variants_or_fsm = VariantsOrFsm::Fsm(Box::new(fsm)); + Ok(StateMachineControlFlow::NeedMore(self)) + } + StateMachineControlFlow::Done(read) => Ok(StateMachineControlFlow::Done(read)), + }, + } + } +} diff --git a/avro/src/lib.rs b/avro/src/lib.rs index 35ea55d1..0b93b863 100644 --- a/avro/src/lib.rs +++ b/avro/src/lib.rs @@ -965,6 +965,7 @@ pub mod types; pub mod util; pub mod validator; pub mod writer; +pub mod decode2; pub use crate::{ bigdecimal::BigDecimal, diff --git a/avro/src/reader/mod.rs b/avro/src/reader/mod.rs index 0ac75c5c..044c9993 100644 --- a/avro/src/reader/mod.rs +++ b/avro/src/reader/mod.rs @@ -14,9 +14,6 @@ use serde::de::DeserializeOwned; use std::{io::Read, marker::PhantomData}; pub use sync::*; -// This is for API compatibility with previous versions -pub use sync::*; - /// Reads the marker bytes from Avro bytes generated earlier by a [`Writer`]. 
/// /// [`Writer`]: crate::Writer diff --git a/avro/src/schema.rs b/avro/src/schema.rs index 5839578b..a6a24629 100644 --- a/avro/src/schema.rs +++ b/avro/src/schema.rs @@ -110,8 +110,8 @@ pub enum Schema { /// Logical type which represents `Decimal` values without predefined scale. /// The underlying type is serialized and deserialized as `Schema::Bytes` BigDecimal, - /// A universally unique identifier, annotating a string. - Uuid, + /// A universally unique identifier. + Uuid(UuidSchema), /// Logical type which represents the number of days since the unix epoch. /// Serialization format is `Schema::Int`. Date, @@ -566,6 +566,7 @@ pub(crate) fn resolve_names( enclosing_namespace: &Namespace, ) -> AvroResult<()> { match schema { + // TODO: Shouldn't Fixed (and thus Uuid/Decimal) also be in here? Schema::Array(schema) => resolve_names(&schema.items, names, enclosing_namespace), Schema::Map(schema) => resolve_names(&schema.types, names, enclosing_namespace), Schema::Union(UnionSchema { schemas, .. }) => { @@ -612,7 +613,7 @@ pub(crate) fn resolve_names( } pub(crate) fn resolve_names_with_schemata( - schemata: &Vec<&Schema>, + schemata: &[&Schema], names: &mut Names, enclosing_namespace: &Namespace, ) -> AvroResult<()> { @@ -883,7 +884,7 @@ impl FixedSchema { } } -/// A description of a Union schema. +/// A description of a Decimal schema. /// /// `scale` defaults to 0 and is an integer greater than or equal to 0 and `precision` is an /// integer greater than 0. @@ -995,6 +996,17 @@ impl PartialEq for UnionSchema { } } +/// The schema of the UUID. +#[derive(Debug, Clone)] +pub enum UuidSchema { + /// The UUID is embedded as a human-readable string. + String, + /// The UUID is in the 16-byte binary form. 
+ Fixed(FixedSchema), + // TODO: I think avro-rs also supports a Bytes variant, but not sure + // TODO: Add tests for Fixed +} + type DecimalMetadata = usize; pub(crate) type Precision = DecimalMetadata; pub(crate) type Scale = DecimalMetadata; @@ -1571,8 +1583,8 @@ impl Parser { parse_as_native_complex(complex, self, enclosing_namespace)?, &[SchemaKind::String, SchemaKind::Fixed], |schema| match schema { - Schema::String => Ok(Schema::Uuid), - Schema::Fixed(FixedSchema { size: 16, .. }) => Ok(Schema::Uuid), + Schema::String => Ok(Schema::Uuid(UuidSchema::String)), + Schema::Fixed(fixed @ FixedSchema { size: 16, .. }) => Ok(Schema::Uuid(UuidSchema::Fixed(fixed))), Schema::Fixed(FixedSchema { size, .. }) => { warn!( "Ignoring uuid logical type for a Fixed schema because its size ({size:?}) is not 16! Schema: {schema:?}" @@ -2199,8 +2211,16 @@ impl Serialize for Schema { map.serialize_entry("logicalType", "big-decimal")?; map.end() } - Schema::Uuid => { + Schema::Uuid(ref uuid_schema) => { let mut map = serializer.serialize_map(None)?; + match uuid_schema { + UuidSchema::String => { + map.serialize_entry("type", "string")?; + } + UuidSchema::Fixed(fixed_schema) => { + map = fixed_schema.serialize_to_map::(map)?; + } + } map.serialize_entry("type", "string")?; map.serialize_entry("logicalType", "uuid")?; map.end() @@ -2539,7 +2559,8 @@ pub mod derive { impl_schema!(f32, Schema::Float); impl_schema!(f64, Schema::Double); impl_schema!(String, Schema::String); - impl_schema!(uuid::Uuid, Schema::Uuid); + // TODO: Maybe change to UuidSchema::Fixed + impl_schema!(uuid::Uuid, Schema::Uuid(UuidSchema::String)); impl_schema!(core::time::Duration, Schema::Duration); impl AvroSchemaComponent for Vec @@ -6698,7 +6719,7 @@ mod tests { "logicalType": "uuid" }); let parse_result = Schema::parse(&schema)?; - assert_eq!(parse_result, Schema::Uuid); + assert_eq!(parse_result, Schema::Uuid(UuidSchema::String)); Ok(()) } @@ -6713,7 +6734,7 @@ mod tests { "logicalType": "uuid" }); let 
parse_result = Schema::parse(&schema)?; - assert_eq!(parse_result, Schema::Uuid); + assert!(matches!(parse_result, Schema::Uuid(UuidSchema::Fixed(_)))); assert_not_logged( r#"Ignoring uuid logical type for a Fixed schema because its size (6) is not 16! Schema: Fixed(FixedSchema { name: Name { name: "FixedUUID", namespace: None }, aliases: None, doc: None, size: 6, attributes: {"logicalType": String("uuid")} })"#, ); diff --git a/avro/src/schema_compatibility.rs b/avro/src/schema_compatibility.rs index be75b911..e08045e2 100644 --- a/avro/src/schema_compatibility.rs +++ b/avro/src/schema_compatibility.rs @@ -529,6 +529,7 @@ mod tests { }; use apache_avro_test_helper::TestResult; use rstest::*; + use crate::schema::UuidSchema; fn int_array_schema() -> Schema { Schema::parse_str(r#"{"type":"array", "items":"int"}"#).unwrap() @@ -1017,8 +1018,9 @@ mod tests { (Schema::String, Schema::Bytes), (Schema::Bytes, Schema::String), // logical types - (Schema::Uuid, Schema::Uuid), - (Schema::Uuid, Schema::String), + // TODO: Add UuidSchema::Fixed + (Schema::Uuid(UuidSchema::String), Schema::Uuid(UuidSchema::String)), + (Schema::Uuid(UuidSchema::String), Schema::String), (Schema::Date, Schema::Int), (Schema::TimeMillis, Schema::Int), (Schema::TimeMicros, Schema::Long), @@ -1028,7 +1030,7 @@ mod tests { (Schema::LocalTimestampMillis, Schema::Long), (Schema::LocalTimestampMicros, Schema::Long), (Schema::LocalTimestampNanos, Schema::Long), - (Schema::String, Schema::Uuid), + (Schema::String, Schema::Uuid(UuidSchema::String)), (Schema::Int, Schema::Date), (Schema::Int, Schema::TimeMillis), (Schema::Long, Schema::TimeMicros), diff --git a/avro/src/schema_equality.rs b/avro/src/schema_equality.rs index 1097594c..51ceb980 100644 --- a/avro/src/schema_equality.rs +++ b/avro/src/schema_equality.rs @@ -24,6 +24,7 @@ use crate::{ }; use log::{debug, error}; use std::{fmt::Debug, sync::OnceLock}; +use crate::schema::UuidSchema; /// A trait that compares two schemata for equality. 
/// To register a custom one use [set_schemata_equality_comparator]. @@ -79,7 +80,6 @@ impl SchemataEq for StructFieldEq { compare_primitive!(Double); compare_primitive!(Bytes); compare_primitive!(String); - compare_primitive!(Uuid); compare_primitive!(BigDecimal); compare_primitive!(Date); compare_primitive!(Duration); @@ -98,6 +98,8 @@ impl SchemataEq for StructFieldEq { return false; } + // TODO: These nested if-let statements can be improved with match + if let Schema::Record(RecordSchema { fields: fields_one, .. }) = schema_one @@ -126,6 +128,20 @@ impl SchemataEq for StructFieldEq { return false; } + if let Schema::Uuid(UuidSchema::String) = schema_one { + if let Schema::Uuid(UuidSchema::String) = schema_two { + return true; + } + return false; + } + + if let Schema::Uuid(UuidSchema::Fixed(FixedSchema { size: size_one, .. })) = schema_one { + if let Schema::Uuid(UuidSchema::Fixed(FixedSchema { size: size_two, .. })) = schema_two { + return size_one == size_two; + } + return false; + } + if let Schema::Fixed(FixedSchema { size: size_one, .. }) = schema_one { if let Schema::Fixed(FixedSchema { size: size_two, .. 
}) = schema_two { return size_one == size_two; @@ -283,7 +299,6 @@ mod tests { test_primitives!(Double); test_primitives!(Bytes); test_primitives!(String); - test_primitives!(Uuid); test_primitives!(BigDecimal); test_primitives!(Date); test_primitives!(Duration); @@ -450,6 +465,61 @@ mod tests { assert_eq!(specification_eq_res, struct_field_eq_res); } + #[test] + fn test_avro_3939_compare_uuid_schemata() { + let schema_one = Schema::Uuid(UuidSchema::Fixed(FixedSchema { + name: Name::from("fixed"), + doc: None, + size: 16, + default: None, + aliases: None, + attributes: BTreeMap::new(), + })); + assert!(!SPECIFICATION_EQ.compare(&schema_one, &Schema::Boolean)); + assert!(!STRUCT_FIELD_EQ.compare(&schema_one, &Schema::Boolean)); + + let schema_two = Schema::Fixed(FixedSchema { + name: Name::from("fixed"), + doc: None, + size: 16, + default: None, + aliases: None, + attributes: BTreeMap::new(), + }); + + let specification_eq_res = SPECIFICATION_EQ.compare(&schema_one, &schema_two); + let struct_field_eq_res = STRUCT_FIELD_EQ.compare(&schema_one, &schema_two); + assert!( + specification_eq_res, + "SpecificationEq: Equality of two Schema::Uuid(Fixed) failed!" + ); + assert!( + struct_field_eq_res, + "StructFieldEq: Equality of two Schema::Uuid(Fixed) failed!" + ); + assert_eq!(specification_eq_res, struct_field_eq_res); + + let schema_three = Schema::Uuid(UuidSchema::String); + assert!(!SPECIFICATION_EQ.compare(&schema_three, &Schema::Boolean)); + assert!(!STRUCT_FIELD_EQ.compare(&schema_three, &Schema::Boolean)); + assert!(!SPECIFICATION_EQ.compare(&schema_three, &schema_one)); + assert!(!STRUCT_FIELD_EQ.compare(&schema_three, &schema_one)); + + let schema_four = Schema::Uuid(UuidSchema::String); + + let specification_eq_res = SPECIFICATION_EQ.compare(&schema_three, &schema_four); + let struct_field_eq_res = STRUCT_FIELD_EQ.compare(&schema_three, &schema_four); + assert!( + specification_eq_res, + "SpecificationEq: Equality of two Schema::Uuid(String) failed!" 
+ ); + assert!( + struct_field_eq_res, + "StructFieldEq: Equality of two Schema::Uuid(String) failed!" + ); + assert_eq!(specification_eq_res, struct_field_eq_res); + } + #[test] fn test_avro_3939_compare_enum_schemata() { let schema_one = Schema::Enum(EnumSchema { diff --git a/avro/src/ser_schema.rs b/avro/src/ser_schema.rs index 7d0ed3e5..07b06753 100644 --- a/avro/src/ser_schema.rs +++ b/avro/src/ser_schema.rs @@ -27,6 +27,7 @@ use crate::{ use bigdecimal::BigDecimal; use serde::{Serialize, ser}; use std::{borrow::Cow, io::Write, str::FromStr}; +use crate::schema::UuidSchema; const COLLECTION_SERIALIZER_ITEM_LIMIT: usize = 1024; const COLLECTION_SERIALIZER_DEFAULT_INIT_ITEM_CAPACITY: usize = 32; @@ -885,7 +886,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { }; match schema { - Schema::String | Schema::Bytes | Schema::Uuid => self.write_bytes(value.as_bytes()), + Schema::String | Schema::Bytes | Schema::Uuid(UuidSchema::String) => self.write_bytes(value.as_bytes()), Schema::BigDecimal => { // If we get a string for a `BigDecimal` type, expect a display string representation, such as "12.75" let decimal_val = @@ -893,7 +894,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { let decimal_bytes = big_decimal_as_bytes(&decimal_val)?; self.write_bytes(decimal_bytes.as_slice()) } - Schema::Fixed(fixed_schema) => { + Schema::Uuid(UuidSchema::Fixed(fixed_schema)) | Schema::Fixed(fixed_schema) => { if value.len() == fixed_schema.size { self.writer .write(value.as_bytes()) @@ -915,7 +916,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { match variant_schema { Schema::String | Schema::Bytes - | Schema::Uuid + | Schema::Uuid(_) | Schema::Fixed(_) | Schema::Ref { name: _ } => { encode_int(i as i32, &mut *self.writer)?; @@ -954,10 +955,10 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { }; match schema { - Schema::String | Schema::Bytes | Schema::Uuid | Schema::BigDecimal => { + Schema::String | Schema::Bytes | 
Schema::Uuid(UuidSchema::String) | Schema::BigDecimal => { self.write_bytes(value) } - Schema::Fixed(fixed_schema) => { + Schema::Uuid(UuidSchema::Fixed(fixed_schema)) | Schema::Fixed(fixed_schema) => { if value.len() == fixed_schema.size { self.writer .write(value) @@ -1017,7 +1018,7 @@ impl<'s, W: Write> SchemaAwareWriteSerializer<'s, W> { match variant_schema { Schema::String | Schema::Bytes - | Schema::Uuid + | Schema::Uuid(_) | Schema::BigDecimal | Schema::Fixed(_) | Schema::Duration diff --git a/avro/src/types.rs b/avro/src/types.rs index 12e5909b..4860b019 100644 --- a/avro/src/types.rs +++ b/avro/src/types.rs @@ -38,6 +38,7 @@ use std::{ str::FromStr, }; use uuid::Uuid; +use crate::schema::UuidSchema; /// Compute the maximum decimal value precision of a byte array of length `len` could hold. fn max_prec_for_len(len: usize) -> Result { @@ -441,14 +442,14 @@ impl Value { (&Value::Decimal(_), &Schema::Decimal { .. }) => None, (&Value::BigDecimal(_), &Schema::BigDecimal) => None, (&Value::Duration(_), &Schema::Duration) => None, - (&Value::Uuid(_), &Schema::Uuid) => None, + (&Value::Uuid(_), &Schema::Uuid(_)) => None, (&Value::Float(_), &Schema::Float) => None, (&Value::Float(_), &Schema::Double) => None, (&Value::Double(_), &Schema::Double) => None, (&Value::Bytes(_), &Schema::Bytes) => None, (&Value::Bytes(_), &Schema::Decimal { .. }) => None, (&Value::String(_), &Schema::String) => None, - (&Value::String(_), &Schema::Uuid) => None, + (&Value::String(_), &Schema::Uuid(UuidSchema::String)) => None, (&Value::Fixed(n, _), &Schema::Fixed(FixedSchema { size, .. })) => { if n != size { Some(format!( @@ -478,6 +479,15 @@ impl Value { None } } + (&Value::Fixed(n, _), &Schema::Uuid(UuidSchema::Fixed(_))) => { + if n != 16 { + Some(format!( + "The value's size ('{n}') must be exactly 16 to be a UUID" + )) + } else { + None + } + } // TODO: check precision against n (&Value::Fixed(_n, _), &Schema::Decimal { .. 
}) => None,
             (Value::String(s), Schema::Enum(EnumSchema { symbols, .. })) => {
@@ -706,7 +716,7 @@ impl Value {
             Schema::LocalTimestampMicros => self.resolve_local_timestamp_micros(),
             Schema::LocalTimestampNanos => self.resolve_local_timestamp_nanos(),
             Schema::Duration => self.resolve_duration(),
-            Schema::Uuid => self.resolve_uuid(),
+            Schema::Uuid(_) => self.resolve_uuid(),
         }
     }
 
@@ -716,6 +726,9 @@ impl Value {
             Value::String(ref string) => {
                 Value::Uuid(Uuid::from_str(string).map_err(Details::ConvertStrToUuid)?)
             }
+            Value::Fixed(16, ref bytes) if bytes.len() == 16 => {
+                Value::Uuid(Uuid::from_slice(bytes).expect("guard ensures exactly 16 bytes"))
+            }
             other => return Err(Details::GetUuid(other).into()),
         })
     }
@@ -1871,7 +1884,7 @@ Field with name '"b"' is not a member of the map items"#,
     #[test]
     fn resolve_uuid() -> TestResult {
         let value = Value::Uuid(Uuid::parse_str("1481531d-ccc9-46d9-a56f-5b67459c0537")?);
-        assert!(value.clone().resolve(&Schema::Uuid).is_ok());
+        assert!(value.clone().resolve(&Schema::Uuid(UuidSchema::String)).is_ok());
         assert!(value.resolve(&Schema::TimestampMicros).is_err());
 
         Ok(())
diff --git a/avro/src/util.rs b/avro/src/util.rs
index d0c3dbf8..4303ba89 100644
--- a/avro/src/util.rs
+++ b/avro/src/util.rs
@@ -20,6 +20,7 @@
 use crate::{AvroResult, Error, error::Details, schema::Documentation};
 use serde_json::{Map, Value};
 use std::{io::Write, sync::OnceLock};
+use oval::Buffer;
 
 /// Maximum number of bytes that can be allocated when decoding
 /// Avro-encoded values. This is a protection against ill-formed
@@ -76,10 +77,41 @@ pub(crate) fn zig_i32<W: Write>(n: i32, buffer: W) -> AvroResult<usize> {
 }
 
 pub(crate) fn zig_i64<W: Write>(n: i64, writer: W) -> AvroResult<usize> {
-    encode_variable(((n << 1) ^ (n >> 63)) as u64, writer)
+    encode_variable(n, writer)
 }
 
-fn encode_variable<W: Write>(mut z: u64, mut writer: W) -> AvroResult<usize> {
+/// Zigzag encode an integer.
+///
+/// Will only advance the buffer if the entire number could be written. This needs at most 10 bytes.
+///
+/// # Returns
+/// If there was not enough room for the number in the buffer it will return `None`, otherwise
+/// `Some` is returned.
+pub(crate) fn encode_variable_buffer(n: impl Into<i64>, buffer: &mut Buffer) -> Option<()> {
+    let n: i64 = n.into();
+    let mut z = ((n << 1) ^ (n >> 63)) as u64; // zigzag on i64: needs the arithmetic shift
+    let mut i: usize = 0;
+    loop {
+        // Get a mutable reference to location i, returning if that is past the buffer
+        let buffer_i = buffer.space().get_mut(i)?;
+        if z <= 0x7F {
+            *buffer_i = (z & 0x7F) as u8;
+            i += 1;
+            break;
+        } else {
+            *buffer_i = (0x80 | (z & 0x7F)) as u8;
+            i += 1;
+            z >>= 7;
+        }
+    }
+    // Only mark the bytes as written (`fill`; `consume` would advance the read cursor) once the whole number is in the buffer
+    buffer.fill(i);
+    Some(())
+}
+
+fn encode_variable<W: Write>(n: i64, mut writer: W) -> AvroResult<usize> {
+    // Zigzag-encode on the signed value; shifting after the u64 cast would mis-encode negatives.
+    let mut z = ((n << 1) ^ (n >> 63)) as u64;
     let mut buffer = [0u8; 10];
     let mut i: usize = 0;
     loop {