diff --git a/arrow-json/src/reader/struct_array.rs b/arrow-json/src/reader/struct_array.rs index 262097ace39..df0d5b8a5b8 100644 --- a/arrow-json/src/reader/struct_array.rs +++ b/arrow-json/src/reader/struct_array.rs @@ -21,6 +21,52 @@ use arrow_array::builder::BooleanBufferBuilder; use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Fields}; +use std::collections::HashMap; + +/// Reusable buffer for tape positions, indexed by (field_idx, row_idx). +/// A value of 0 indicates the field is absent for that row. +struct FieldTapePositions { + data: Vec, + row_count: usize, +} + +impl FieldTapePositions { + fn new() -> Self { + Self { + data: Vec::new(), + row_count: 0, + } + } + + fn resize(&mut self, field_count: usize, row_count: usize) -> Result<(), ArrowError> { + let total_len = field_count.checked_mul(row_count).ok_or_else(|| { + ArrowError::JsonError(format!( + "FieldTapePositions buffer size overflow for rows={row_count} fields={field_count}" + )) + })?; + self.data.clear(); + self.data.resize(total_len, 0); + self.row_count = row_count; + Ok(()) + } + + fn try_set(&mut self, field_idx: usize, row_idx: usize, pos: u32) -> Option<()> { + let idx = field_idx + .checked_mul(self.row_count)? + .checked_add(row_idx)?; + *self.data.get_mut(idx)? = pos; + Some(()) + } + + fn set(&mut self, field_idx: usize, row_idx: usize, pos: u32) { + self.data[field_idx * self.row_count + row_idx] = pos; + } + + fn field_positions(&self, field_idx: usize) -> &[u32] { + let start = field_idx * self.row_count; + &self.data[start..start + self.row_count] + } +} pub struct StructArrayDecoder { data_type: DataType, @@ -28,6 +74,8 @@ pub struct StructArrayDecoder { strict_mode: bool, is_nullable: bool, struct_mode: StructMode, + field_name_to_index: Option>, + field_tape_positions: FieldTapePositions, } impl StructArrayDecoder { @@ -38,22 +86,31 @@ impl StructArrayDecoder { is_nullable: bool, struct_mode: StructMode, ) -> Result { - let decoders = struct_fields(&data_type) - .iter() - .map(|f| { - // If this struct nullable, need to permit nullability in child array - // StructArrayDecoder::decode verifies that if the child is not nullable - // it doesn't contain any nulls not masked by its parent - let nullable = f.is_nullable() || is_nullable; - make_decoder( - f.data_type().clone(), - coerce_primitive, - strict_mode, - nullable, - struct_mode, - ) - }) - .collect::, ArrowError>>()?; + let (decoders, field_name_to_index) = { + let fields = struct_fields(&data_type); + let decoders = fields + .iter() + .map(|f| { + // If this struct nullable, need to permit nullability in child array + // StructArrayDecoder::decode verifies that if the child is not nullable + // it doesn't contain any nulls not masked by its parent + let nullable = f.is_nullable() || is_nullable; + make_decoder( + f.data_type().clone(), + coerce_primitive, + strict_mode, + nullable, + struct_mode, + ) + }) + .collect::, ArrowError>>()?; + let field_name_to_index = if struct_mode == StructMode::ObjectOnly { + build_field_index(fields) + } else { + None + }; + (decoders, field_name_to_index) + }; Ok(Self { data_type, @@ -61,6 +118,8 @@ impl StructArrayDecoder { strict_mode, is_nullable, struct_mode, + field_name_to_index, + field_tape_positions: FieldTapePositions::new(), }) } } @@ -68,89 +127,99 @@ impl StructArrayDecoder { impl ArrayDecoder for StructArrayDecoder { fn decode(&mut self, tape: &Tape<'_>, pos: &[u32]) -> Result { let fields = struct_fields(&self.data_type); - let mut child_pos: Vec<_> = (0..fields.len()).map(|_| vec![0; pos.len()]).collect(); - + let row_count = pos.len(); + let field_count = fields.len(); + self.field_tape_positions.resize(field_count, row_count)?; let mut nulls = self .is_nullable .then(|| BooleanBufferBuilder::new(pos.len())); - // We avoid having the match on self.struct_mode inside the hot loop for performance - // TODO: Investigate how to extract duplicated logic. - match self.struct_mode { - StructMode::ObjectOnly => { - for (row, p) in pos.iter().enumerate() { - let end_idx = match (tape.get(*p), nulls.as_mut()) { - (TapeElement::StartObject(end_idx), None) => end_idx, - (TapeElement::StartObject(end_idx), Some(nulls)) => { - nulls.append(true); - end_idx - } - (TapeElement::Null, Some(nulls)) => { - nulls.append(false); - continue; - } - (_, _) => return Err(tape.error(*p, "{")), - }; - - let mut cur_idx = *p + 1; - while cur_idx < end_idx { - // Read field name - let field_name = match tape.get(cur_idx) { - TapeElement::String(s) => tape.get_string(s), - _ => return Err(tape.error(cur_idx, "field name")), + { + // We avoid having the match on self.struct_mode inside the hot loop for performance + // TODO: Investigate how to extract duplicated logic. + match self.struct_mode { + StructMode::ObjectOnly => { + for (row, p) in pos.iter().enumerate() { + let end_idx = match (tape.get(*p), nulls.as_mut()) { + (TapeElement::StartObject(end_idx), None) => end_idx, + (TapeElement::StartObject(end_idx), Some(nulls)) => { + nulls.append(true); + end_idx + } + (TapeElement::Null, Some(nulls)) => { + nulls.append(false); + continue; + } + (_, _) => return Err(tape.error(*p, "{")), }; - // Update child pos if match found - match fields.iter().position(|x| x.name() == field_name) { - Some(field_idx) => child_pos[field_idx][row] = cur_idx + 1, - None => { - if self.strict_mode { - return Err(ArrowError::JsonError(format!( - "column '{field_name}' missing from schema", - ))); + let mut cur_idx = *p + 1; + while cur_idx < end_idx { + // Read field name + let field_name = match tape.get(cur_idx) { + TapeElement::String(s) => tape.get_string(s), + _ => return Err(tape.error(cur_idx, "field name")), + }; + + // Update child pos if match found + let field_idx = match &self.field_name_to_index { + Some(map) => map.get(field_name).copied(), + None => fields.iter().position(|x| x.name() == field_name), + }; + match field_idx { + Some(field_idx) => { + self.field_tape_positions.set(field_idx, row, cur_idx + 1); + } + None => { + if self.strict_mode { + return Err(ArrowError::JsonError(format!( + "column '{field_name}' missing from schema", + ))); + } } } + // Advance to next field + cur_idx = tape.next(cur_idx + 1, "field value")?; } - // Advance to next field - cur_idx = tape.next(cur_idx + 1, "field value")?; } } - } - StructMode::ListOnly => { - for (row, p) in pos.iter().enumerate() { - let end_idx = match (tape.get(*p), nulls.as_mut()) { - (TapeElement::StartList(end_idx), None) => end_idx, - (TapeElement::StartList(end_idx), Some(nulls)) => { - nulls.append(true); - end_idx - } - (TapeElement::Null, Some(nulls)) => { - nulls.append(false); - continue; - } - (_, _) => return Err(tape.error(*p, "[")), - }; + StructMode::ListOnly => { + for (row, p) in pos.iter().enumerate() { + let end_idx = match (tape.get(*p), nulls.as_mut()) { + (TapeElement::StartList(end_idx), None) => end_idx, + (TapeElement::StartList(end_idx), Some(nulls)) => { + nulls.append(true); + end_idx + } + (TapeElement::Null, Some(nulls)) => { + nulls.append(false); + continue; + } + (_, _) => return Err(tape.error(*p, "[")), + }; - let mut cur_idx = *p + 1; - let mut entry_idx = 0; - while cur_idx < end_idx { - if entry_idx >= fields.len() { + let mut cur_idx = *p + 1; + let mut entry_idx = 0; + while cur_idx < end_idx { + self.field_tape_positions + .try_set(entry_idx, row, cur_idx) + .ok_or_else(|| { + ArrowError::JsonError(format!( + "found extra columns for {} fields", + fields.len() + )) + })?; + entry_idx += 1; + // Advance to next field + cur_idx = tape.next(cur_idx, "field value")?; + } + if entry_idx != fields.len() { return Err(ArrowError::JsonError(format!( - "found extra columns for {} fields", + "found {} columns for {} fields", + entry_idx, fields.len() ))); } - child_pos[entry_idx][row] = cur_idx; - entry_idx += 1; - // Advance to next field - cur_idx = tape.next(cur_idx, "field value")?; - } - if entry_idx != fields.len() { - return Err(ArrowError::JsonError(format!( - "found {} columns for {} fields", - entry_idx, - fields.len() - ))); } } } @@ -159,10 +228,11 @@ impl ArrayDecoder for StructArrayDecoder { let child_data = self .decoders .iter_mut() - .zip(child_pos) + .enumerate() .zip(fields) - .map(|((d, pos), f)| { - d.decode(tape, &pos).map_err(|e| match e { + .map(|((field_idx, d), f)| { + let pos = self.field_tape_positions.field_positions(field_idx); + d.decode(tape, pos).map_err(|e| match e { ArrowError::JsonError(s) => { ArrowError::JsonError(format!("whilst decoding field '{}': {s}", f.name())) } @@ -205,3 +275,20 @@ fn struct_fields(data_type: &DataType) -> &Fields { _ => unreachable!(), } } + +fn build_field_index(fields: &Fields) -> Option> { + // Heuristic threshold: for small field counts, linear scan avoids HashMap overhead. + const FIELD_INDEX_LINEAR_THRESHOLD: usize = 16; + if fields.len() < FIELD_INDEX_LINEAR_THRESHOLD { + return None; + } + + let mut map = HashMap::with_capacity(fields.len()); + for (idx, field) in fields.iter().enumerate() { + let name = field.name(); + if !map.contains_key(name) { + map.insert(name.to_string(), idx); + } + } + Some(map) +}