Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions parquet-variant/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
// specific language governing permissions and limitations
// under the License.
use crate::decoder::{VariantBasicType, VariantPrimitiveType};
use crate::Variant;
use crate::{ShortString, Variant};
use std::collections::HashMap;

const BASIC_TYPE_BITS: u8 = 2;
const MAX_SHORT_STRING_SIZE: usize = 0x3F;
const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();

fn primitive_header(primitive_type: VariantPrimitiveType) -> u8 {
Expand Down Expand Up @@ -114,11 +113,11 @@ fn make_room_for_header(buffer: &mut Vec<u8>, start_pos: usize, header_size: usi
/// };
/// assert_eq!(
/// variant_object.field_by_name("first_name").unwrap(),
/// Some(Variant::ShortString("Jiaying"))
/// Some(Variant::from("Jiaying"))
/// );
/// assert_eq!(
/// variant_object.field_by_name("last_name").unwrap(),
/// Some(Variant::ShortString("Li"))
/// Some(Variant::from("Li"))
/// );
/// ```
///
Expand Down Expand Up @@ -281,17 +280,18 @@ impl VariantBuilder {
self.buffer.extend_from_slice(value);
}

fn append_short_string(&mut self, value: ShortString) {
let inner = value.0;
self.buffer.push(short_string_header(inner.len()));
self.buffer.extend_from_slice(inner.as_bytes());
}

fn append_string(&mut self, value: &str) {
if value.len() <= MAX_SHORT_STRING_SIZE {
self.buffer.push(short_string_header(value.len()));
self.buffer.extend_from_slice(value.as_bytes());
} else {
self.buffer
.push(primitive_header(VariantPrimitiveType::String));
self.buffer
.extend_from_slice(&(value.len() as u32).to_le_bytes());
self.buffer.extend_from_slice(value.as_bytes());
}
self.buffer
.push(primitive_header(VariantPrimitiveType::String));
self.buffer
.extend_from_slice(&(value.len() as u32).to_le_bytes());
self.buffer.extend_from_slice(value.as_bytes());
}

/// Add key to dictionary, return its ID
Expand Down Expand Up @@ -390,7 +390,8 @@ impl VariantBuilder {
Variant::Float(v) => self.append_float(v),
Variant::Double(v) => self.append_double(v),
Variant::Binary(v) => self.append_binary(v),
Variant::String(s) | Variant::ShortString(s) => self.append_string(s),
Variant::String(s) => self.append_string(s),
Variant::ShortString(s) => self.append_short_string(s),
Variant::Object(_) | Variant::List(_) => {
unreachable!("Object and List variants cannot be created through Into<Variant>")
}
Expand Down Expand Up @@ -639,7 +640,7 @@ mod tests {
builder.append_value("hello");
let (metadata, value) = builder.finish();
let variant = Variant::try_new(&metadata, &value).unwrap();
assert_eq!(variant, Variant::ShortString("hello"));
assert_eq!(variant, Variant::ShortString(ShortString("hello")));
}

{
Expand Down Expand Up @@ -688,7 +689,7 @@ mod tests {
assert_eq!(val1, Variant::Int8(2));

let val2 = list.get(2).unwrap();
assert_eq!(val2, Variant::ShortString("test"));
assert_eq!(val2, Variant::ShortString(ShortString("test")));
}
_ => panic!("Expected an array variant, got: {:?}", variant),
}
Expand Down
7 changes: 4 additions & 3 deletions parquet-variant/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
use crate::utils::{array_from_slice, slice_from_slice, string_from_slice};
use crate::ShortString;

use arrow_schema::ArrowError;
use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, Utc};
Expand Down Expand Up @@ -273,10 +274,10 @@ pub(crate) fn decode_long_string(data: &[u8]) -> Result<&str, ArrowError> {
}

/// Decodes a short string from the value section of a variant.
pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<&str, ArrowError> {
pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<ShortString, ArrowError> {
let len = (metadata >> 2) as usize;
let string = string_from_slice(data, 0..len)?;
Ok(string)
ShortString::try_new(string)
}

#[cfg(test)]
Expand Down Expand Up @@ -420,7 +421,7 @@ mod tests {
fn test_short_string() -> Result<(), ArrowError> {
let data = [b'H', b'e', b'l', b'l', b'o', b'o'];
let result = decode_short_string(1 | 5 << 2, &data)?;
assert_eq!(result, "Hello");
assert_eq!(result.0, "Hello");
Ok(())
}

Expand Down
70 changes: 58 additions & 12 deletions parquet-variant/src/variant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,37 @@ mod list;
mod metadata;
mod object;

const MAX_SHORT_STRING_SIZE: usize = 0x3F;

/// A Variant [`ShortString`]
///
/// This implementation is a zero cost wrapper over `&str` that ensures
/// the length of the underlying string is a valid Variant short string (63 bytes or less)
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ShortString<'a>(pub(crate) &'a str);

impl<'a> ShortString<'a> {
/// Attempts to interpret `value` as a variant short string value.
///
/// # Validation
///
/// This constructor verifies that `value` is shorter than or equal to `MAX_SHORT_STRING_SIZE`
pub fn try_new(value: &'a str) -> Result<Self, ArrowError> {
if value.len() > MAX_SHORT_STRING_SIZE {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth adding a comment here that we are indeed supposed to check bytes and not characters, that's a common confusion with "string length"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a great idea -- maybe we can even name the constant MAX_SHORT_STRING_BYTES to make it more self describing

return Err(ArrowError::InvalidArgumentError(format!(
"value is larger than {MAX_SHORT_STRING_SIZE} bytes"
)));
}

Ok(Self(value))
}

/// Returns the underlying Variant short string as a &str
pub fn as_str(&self) -> &'a str {
self.0
}
}

/// Represents a [Parquet Variant]
///
/// The lifetimes `'m` and `'v` are for metadata and value buffers, respectively.
Expand Down Expand Up @@ -85,7 +116,7 @@ mod object;
///
/// ## Creating `Variant` from Rust Types
/// ```
/// # use parquet_variant::Variant;
/// use parquet_variant::{Variant};
/// // variants can be directly constructed
/// let variant = Variant::Int32(123);
/// // or constructed via `From` impls
Expand All @@ -98,7 +129,7 @@ mod object;
/// let value = [0x09, 0x48, 0x49];
/// // parse the header metadata
/// assert_eq!(
/// Variant::ShortString("HI"),
/// Variant::from("HI"),
/// Variant::try_new(&metadata, &value).unwrap()
/// );
/// ```
Expand Down Expand Up @@ -152,7 +183,7 @@ pub enum Variant<'m, 'v> {
/// Primitive (type_id=1): STRING
String(&'v str),
/// Short String (type_id=2): STRING
ShortString(&'v str),
ShortString(ShortString<'v>),
// need both metadata & value
/// Object (type_id=3): N/A
Object(VariantObject<'m, 'v>),
Expand All @@ -165,12 +196,12 @@ impl<'m, 'v> Variant<'m, 'v> {
///
/// # Example
/// ```
/// # use parquet_variant::{Variant, VariantMetadata};
/// # use parquet_variant::{Variant, VariantMetadata, ShortString};
/// let metadata = [0x01, 0x00, 0x00];
/// let value = [0x09, 0x48, 0x49];
/// // parse the header metadata
/// assert_eq!(
/// Variant::ShortString("HI"),
/// Variant::from("HI"),
/// Variant::try_new(&metadata, &value).unwrap()
/// );
/// ```
Expand All @@ -189,7 +220,7 @@ impl<'m, 'v> Variant<'m, 'v> {
/// // parse the header metadata first
/// let metadata = VariantMetadata::try_new(&metadata).unwrap();
/// assert_eq!(
/// Variant::ShortString("HI"),
/// Variant::from("HI"),
/// Variant::try_new_with_metadata(metadata, &value).unwrap()
/// );
/// ```
Expand Down Expand Up @@ -428,11 +459,11 @@ impl<'m, 'v> Variant<'m, 'v> {
/// # Examples
///
/// ```
/// use parquet_variant::Variant;
/// use parquet_variant::{Variant};
///
/// // you can extract a string from string variants
/// let s = "hello!";
/// let v1 = Variant::ShortString(s);
/// let v1 = Variant::from(s);
/// assert_eq!(v1.as_string(), Some(s));
///
/// // but not from other variants
Expand All @@ -441,7 +472,7 @@ impl<'m, 'v> Variant<'m, 'v> {
/// ```
pub fn as_string(&'v self) -> Option<&'v str> {
match self {
Variant::String(s) | Variant::ShortString(s) => Some(s),
Variant::String(s) | Variant::ShortString(ShortString(s)) => Some(s),
_ => None,
}
}
Expand Down Expand Up @@ -861,10 +892,25 @@ impl<'v> From<&'v [u8]> for Variant<'_, 'v> {

impl<'v> From<&'v str> for Variant<'_, 'v> {
fn from(value: &'v str) -> Self {
if value.len() < 64 {
Variant::ShortString(value)
} else {
if value.len() > MAX_SHORT_STRING_SIZE {
Variant::String(value)
} else {
Variant::ShortString(ShortString(value))
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_construct_short_string() {
let short_string = ShortString::try_new("norm").expect("should fit in short string");
assert_eq!(short_string.as_str(), "norm");

let long_string = "a".repeat(MAX_SHORT_STRING_SIZE + 1);
let res = ShortString::try_new(&long_string);
assert!(res.is_err());
}
}
17 changes: 13 additions & 4 deletions parquet-variant/tests/variant_interop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::fs;
use std::path::{Path, PathBuf};

use chrono::NaiveDate;
use parquet_variant::{Variant, VariantBuilder};
use parquet_variant::{ShortString, Variant, VariantBuilder};

fn cases_dir() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR"))
Expand Down Expand Up @@ -76,7 +76,7 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> {
("primitive_string", Variant::String("This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥\u{fe0f}, 🎣 and 🤦!!")),
("primitive_timestamp", Variant::TimestampMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(16, 34, 56, 780).unwrap().and_utc())),
("primitive_timestampntz", Variant::TimestampNtzMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap())),
("short_string", Variant::ShortString("Less than 64 bytes (❤\u{fe0f} with utf8)")),
("short_string", Variant::ShortString(ShortString::try_new("Less than 64 bytes (❤\u{fe0f} with utf8)").unwrap())),
]
}
#[test]
Expand Down Expand Up @@ -130,11 +130,20 @@ fn variant_object_primitive() {
),
("int_field", Variant::Int8(1)),
("null_field", Variant::Null),
("string_field", Variant::ShortString("Apache Parquet")),
(
"string_field",
Variant::ShortString(
ShortString::try_new("Apache Parquet")
.expect("value should fit inside a short string"),
),
),
(
// apparently spark wrote this as a string (not a timestamp)
"timestamp_field",
Variant::ShortString("2025-04-16T12:34:56.78"),
Variant::ShortString(
ShortString::try_new("2025-04-16T12:34:56.78")
.expect("value should fit inside a short string"),
),
),
];
let actual_fields: Vec<_> = variant_object.iter().collect();
Expand Down
Loading