From 7fc663c2e40be2928778102386bbf76962dd2cdc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 29 Dec 2023 16:53:31 -0700 Subject: [PATCH] Implement serde for CSV and Parquet FileSinkExec (#8646) * Add serde for Csv and Parquet sink * Add tests * parquet test passes * save progress * add compression type to csv serde * remove hard-coded compression from CSV serde --- .../core/src/datasource/file_format/csv.rs | 11 +- .../src/datasource/file_format/parquet.rs | 9 +- datafusion/proto/proto/datafusion.proto | 40 +- datafusion/proto/src/generated/pbjson.rs | 517 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 59 +- datafusion/proto/src/logical_plan/mod.rs | 43 +- .../proto/src/physical_plan/from_proto.rs | 38 +- datafusion/proto/src/physical_plan/mod.rs | 91 +++ .../proto/src/physical_plan/to_proto.rs | 46 +- .../tests/cases/roundtrip_physical_plan.rs | 125 ++++- 10 files changed, 922 insertions(+), 57 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index d4e63904bdd4..7a0af3ff0809 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -437,7 +437,7 @@ impl BatchSerializer for CsvSerializer { } /// Implements [`DataSink`] for writing to a CSV file. -struct CsvSink { +pub struct CsvSink { /// Config options for writing data config: FileSinkConfig, } @@ -461,9 +461,16 @@ impl DisplayAs for CsvSink { } impl CsvSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } + async fn multipartput_all( &self, data: SendableRecordBatchStream, diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 7044acccd6dc..9729bfa163af 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -621,7 +621,7 @@ async fn fetch_statistics( } /// Implements [`DataSink`] for writing to a parquet file. -struct ParquetSink { +pub struct ParquetSink { /// Config options for writing data config: FileSinkConfig, } @@ -645,10 +645,15 @@ impl DisplayAs for ParquetSink { } impl ParquetSink { - fn new(config: FileSinkConfig) -> Self { + /// Create from config. + pub fn new(config: FileSinkConfig) -> Self { Self { config } } + /// Retrieve the inner [`FileSinkConfig`]. + pub fn config(&self) -> &FileSinkConfig { + &self.config + } /// Converts table schema to writer schema, which may differ in the case /// of hive style partitioning where some columns are removed from the /// underlying files. diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 59b82efcbb43..d5f8397aa30c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1187,6 +1187,8 @@ message PhysicalPlanNode { SymmetricHashJoinExecNode symmetric_hash_join = 25; InterleaveExecNode interleave = 26; PlaceholderRowExecNode placeholder_row = 27; + CsvSinkExecNode csv_sink = 28; + ParquetSinkExecNode parquet_sink = 29; } } @@ -1220,20 +1222,22 @@ message ParquetWriterOptions { } message CsvWriterOptions { + // Compression type + CompressionTypeVariant compression = 1; // Optional column delimiter. 
Defaults to `b','` - string delimiter = 1; + string delimiter = 2; // Whether to write column names as file headers. Defaults to `true` - bool has_header = 2; + bool has_header = 3; // Optional date format for date arrays - string date_format = 3; + string date_format = 4; // Optional datetime format for datetime arrays - string datetime_format = 4; + string datetime_format = 5; // Optional timestamp format for timestamp arrays - string timestamp_format = 5; + string timestamp_format = 6; // Optional time format for time arrays - string time_format = 6; + string time_format = 7; // Optional value to represent null - string null_value = 7; + string null_value = 8; } message WriterProperties { @@ -1270,6 +1274,28 @@ message JsonSinkExecNode { PhysicalSortExprNodeCollection sort_order = 4; } +message CsvSink { + FileSinkConfig config = 1; +} + +message CsvSinkExecNode { + PhysicalPlanNode input = 1; + CsvSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + +message ParquetSink { + FileSinkConfig config = 1; +} + +message ParquetSinkExecNode { + PhysicalPlanNode input = 1; + ParquetSink sink = 2; + Schema sink_schema = 3; + PhysicalSortExprNodeCollection sort_order = 4; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 956244ffdbc2..12e834d75adf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5151,6 +5151,241 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CsvSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: 
serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(CsvSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CsvSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(CsvSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CsvWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5159,6 +5394,9 @@ impl serde::Serialize for CsvWriterOptions { { use serde::ser::SerializeStruct; let mut len = 0; + if self.compression != 0 { + len += 1; + } if !self.delimiter.is_empty() { len += 1; } @@ -5181,6 +5419,11 @@ impl serde::Serialize for CsvWriterOptions { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } if !self.delimiter.is_empty() { struct_ser.serialize_field("delimiter", &self.delimiter)?; } @@ -5212,6 +5455,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "compression", "delimiter", "has_header", "hasHeader", @@ -5229,6 +5473,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { #[allow(clippy::enum_variant_names)] enum GeneratedField { + Compression, Delimiter, HasHeader, DateFormat, @@ -5257,6 +5502,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { E: serde::de::Error, { match value { + "compression" => Ok(GeneratedField::Compression), "delimiter" => Ok(GeneratedField::Delimiter), "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), @@ -5283,6 +5529,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { where V: serde::de::MapAccess<'de>, { + let mut compression__ = None; let mut delimiter__ = None; let mut has_header__ = None; let mut date_format__ = None; @@ -5292,6 +5539,12 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { let mut null_value__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? 
as i32); + } GeneratedField::Delimiter => { if delimiter__.is_some() { return Err(serde::de::Error::duplicate_field("delimiter")); @@ -5337,6 +5590,7 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } } Ok(CsvWriterOptions { + compression: compression__.unwrap_or_default(), delimiter: delimiter__.unwrap_or_default(), has_header: has_header__.unwrap_or_default(), date_format: date_format__.unwrap_or_default(), @@ -15398,6 +15652,241 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for ParquetSink { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.config.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSink { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "config", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Config, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "config" => Ok(GeneratedField::Config), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSink; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSink") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut config__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + } + } + Ok(ParquetSink { + config: config__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetSinkExecNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + Sink, + SinkSchema, + SortOrder, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetSinkExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.ParquetSinkExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; + while let Some(k) = map_.next_key()? 
{ + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); + } + sink__ = map_.next_value()?; + } + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); + } + sink_schema__ = map_.next_value()?; + } + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); + } + sort_order__ = map_.next_value()?; + } + } + } + Ok(ParquetSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for ParquetWriterOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -18484,6 +18973,12 @@ impl serde::Serialize for PhysicalPlanNode { physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { struct_ser.serialize_field("placeholderRow", v)?; } + physical_plan_node::PhysicalPlanType::CsvSink(v) => { + struct_ser.serialize_field("csvSink", v)?; + } + physical_plan_node::PhysicalPlanType::ParquetSink(v) => { + struct_ser.serialize_field("parquetSink", v)?; + } } } struct_ser.end() @@ -18535,6 +19030,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "interleave", "placeholder_row", "placeholderRow", + "csv_sink", + "csvSink", + "parquet_sink", + "parquetSink", ]; #[allow(clippy::enum_variant_names)] @@ -18565,6 +19064,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { SymmetricHashJoin, Interleave, PlaceholderRow, + CsvSink, + ParquetSink, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18612,6 +19113,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), "interleave" => Ok(GeneratedField::Interleave), "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), + "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), + "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18814,6 +19317,20 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { return Err(serde::de::Error::duplicate_field("placeholderRow")); } physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) +; + } + GeneratedField::CsvSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) +; + } + GeneratedField::ParquetSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 32e892e663ef..4ee0b70325ca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1566,7 +1566,7 @@ pub mod owned_table_reference { pub struct 
PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" )] pub physical_plan_type: ::core::option::Option, } @@ -1629,6 +1629,10 @@ pub mod physical_plan_node { Interleave(super::InterleaveExecNode), #[prost(message, tag = "27")] PlaceholderRow(super::PlaceholderRowExecNode), + #[prost(message, tag = "28")] + CsvSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "29")] + ParquetSink(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1673,26 +1677,29 @@ pub struct ParquetWriterOptions { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, /// Optional column delimiter. Defaults to `b','` - #[prost(string, tag = "1")] + #[prost(string, tag = "2")] pub delimiter: ::prost::alloc::string::String, /// Whether to write column names as file headers. Defaults to `true` - #[prost(bool, tag = "2")] + #[prost(bool, tag = "3")] pub has_header: bool, /// Optional date format for date arrays - #[prost(string, tag = "3")] + #[prost(string, tag = "4")] pub date_format: ::prost::alloc::string::String, /// Optional datetime format for datetime arrays - #[prost(string, tag = "4")] + #[prost(string, tag = "5")] pub datetime_format: ::prost::alloc::string::String, /// Optional timestamp format for timestamp arrays - #[prost(string, tag = "5")] + #[prost(string, tag = "6")] pub timestamp_format: ::prost::alloc::string::String, /// Optional time format for time arrays - #[prost(string, tag = "6")] + #[prost(string, tag = "7")] pub time_format: ::prost::alloc::string::String, /// Optional value to represent null - #[prost(string, tag = "7")] + #[prost(string, tag = "8")] pub null_value: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] @@ -1753,6 +1760,42 @@ pub struct JsonSinkExecNode { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: 
::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] pub node: ::prost::alloc::vec::Vec, diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index dbed0252d051..5ee88c3d5328 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1648,28 +1648,10 @@ impl AsLogicalPlan for LogicalPlanNode { match opt.as_ref() { FileTypeWriterOptions::CSV(csv_opts) => { let csv_options = &csv_opts.writer_options; - let csv_writer_options = protobuf::CsvWriterOptions { - delimiter: (csv_options.delimiter() as char) - .to_string(), - has_header: csv_options.header(), - date_format: csv_options - .date_format() - .unwrap_or("") - .to_owned(), - datetime_format: csv_options - .datetime_format() - .unwrap_or("") - .to_owned(), - timestamp_format: csv_options - .timestamp_format() - .unwrap_or("") - .to_owned(), - time_format: csv_options - .time_format() - .unwrap_or("") - .to_owned(), - null_value: csv_options.null().to_owned(), - }; + let csv_writer_options = csv_writer_options_to_proto( + csv_options, + (&csv_opts.compression).into(), + ); let csv_options = file_type_writer_options::FileType::CsvOptions( csv_writer_options, @@ -1724,6 +1706,23 @@ impl AsLogicalPlan for LogicalPlanNode { } } +pub(crate) fn csv_writer_options_to_proto( + csv_options: &WriterBuilder, + compression: &CompressionTypeVariant, +) -> protobuf::CsvWriterOptions { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::CsvWriterOptions { + compression: compression.into(), + delimiter: (csv_options.delimiter() as char).to_string(), + has_header: csv_options.header(), + date_format: csv_options.date_format().unwrap_or("").to_owned(), + datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), + timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), + time_format: csv_options.time_format().unwrap_or("").to_owned(), + null_value: csv_options.null().to_owned(), + } +} + pub(crate) fn csv_writer_options_from_proto( writer_options: &protobuf::CsvWriterOptions, ) -> Result { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 6f1e811510c6..8ad6d679df4d 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -22,7 +22,10 @@ use std::sync::Arc; use arrow::compute::SortOptions; use datafusion::arrow::datatypes::Schema; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{FileRange, ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; @@ -713,6 +716,23 @@ impl TryFrom<&protobuf::JsonSink> for JsonSink { } } +#[cfg(feature = "parquet")] +impl TryFrom<&protobuf::ParquetSink> for ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::ParquetSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} + +impl TryFrom<&protobuf::CsvSink> for CsvSink { + type Error = DataFusionError; + + fn try_from(value: &protobuf::CsvSink) -> Result { + Ok(Self::new(convert_required!(value.config)?)) + } +} 
+ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { type Error = DataFusionError; @@ -768,16 +788,16 @@ impl TryFrom<&protobuf::FileTypeWriterOptions> for FileTypeWriterOptions { .file_type .as_ref() .ok_or_else(|| proto_error("Missing required file_type field in protobuf"))?; + match file_type { - protobuf::file_type_writer_options::FileType::JsonOptions(opts) => Ok( - Self::JSON(JsonWriterOptions::new(opts.compression().into())), - ), - protobuf::file_type_writer_options::FileType::CsvOptions(opt) => { - let write_options = csv_writer_options_from_proto(opt)?; - Ok(Self::CSV(CsvWriterOptions::new( - write_options, - CompressionTypeVariant::UNCOMPRESSED, - ))) + protobuf::file_type_writer_options::FileType::JsonOptions(opts) => { + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::JSON(JsonWriterOptions::new(compression))) + } + protobuf::file_type_writer_options::FileType::CsvOptions(opts) => { + let write_options = csv_writer_options_from_proto(opts)?; + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(Self::CSV(CsvWriterOptions::new(write_options, compression))) } protobuf::file_type_writer_options::FileType::ParquetOptions(opt) => { let props = opt.writer_properties.clone().unwrap_or_default(); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 24ede3fcaf62..95becb3fe4b3 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,9 +21,12 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; use datafusion::datasource::file_format::json::JsonSink; #[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; +#[cfg(feature = "parquet")] use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; @@ -921,6 +924,68 @@ impl AsExecutionPlan for PhysicalPlanNode { sort_order, ))) } + PhysicalPlanType::CsvSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: CsvSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::<Result<Vec<_>>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } + PhysicalPlanType::ParquetSink(sink) => { + let input = + into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))?
+ .try_into()?; + let sink_schema = convert_required!(sink.sink_schema)?; + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + collection + .physical_sort_expr_nodes + .iter() + .map(|proto| { + parse_physical_sort_expr(proto, registry, &sink_schema) + .map(Into::into) + }) + .collect::<Result<Vec<_>>>() + }) + .transpose()?; + Ok(Arc::new(FileSinkExec::new( + input, + Arc::new(data_sink), + Arc::new(sink_schema), + sort_order, + ))) + } } } @@ -1678,6 +1743,32 @@ impl AsExecutionPlan for PhysicalPlanNode { }); } + if let Some(sink) = exec.sink().as_any().downcast_ref::<CsvSink>() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::CsvSink(Box::new( + protobuf::CsvSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + + if let Some(sink) = exec.sink().as_any().downcast_ref::<ParquetSink>() { + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( + protobuf::ParquetSinkExecNode { + input: Some(Box::new(input)), + sink: Some(sink.try_into()?), + sink_schema: Some(exec.schema().as_ref().try_into()?), + sort_order, + }, + ))), + }); + } + // If unknown DataSink then let extension handle it } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e9cdb34cf1b9..f4e3f9e4dca7 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -28,7 +28,12 @@ use crate::protobuf::{ ScalarValue, }; +#[cfg(feature = "parquet")] +use datafusion::datasource::file_format::parquet::ParquetSink; + +use crate::logical_plan::{csv_writer_options_to_proto, writer_properties_to_proto}; use datafusion::datasource::{ + file_format::csv::CsvSink, file_format::json::JsonSink, listing::{FileRange, PartitionedFile}, physical_plan::FileScanConfig, @@ -814,6 +819,27 @@ impl TryFrom<&JsonSink> for protobuf::JsonSink { } } +impl TryFrom<&CsvSink> for protobuf::CsvSink { + type Error = DataFusionError; + + fn try_from(value: &CsvSink) -> Result<Self, Self::Error> { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + +#[cfg(feature = "parquet")] +impl TryFrom<&ParquetSink> for protobuf::ParquetSink { + type Error = DataFusionError; + + fn try_from(value: &ParquetSink) -> Result<Self, Self::Error> { + Ok(Self { + config: Some(value.config().try_into()?), + }) + } +} + impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { type Error = DataFusionError; @@ -870,13 +896,21 @@ impl TryFrom<&FileTypeWriterOptions> for protobuf::FileTypeWriterOptions { fn try_from(opts: &FileTypeWriterOptions) -> Result<Self, Self::Error> { let file_type = match opts { #[cfg(feature = "parquet")] - FileTypeWriterOptions::Parquet(ParquetWriterOptions { - writer_options: _, - }) => return not_impl_err!("Parquet file sink protobuf serialization"), + FileTypeWriterOptions::Parquet(ParquetWriterOptions { writer_options }) => { + protobuf::file_type_writer_options::FileType::ParquetOptions( + protobuf::ParquetWriterOptions { + writer_properties: Some(writer_properties_to_proto( + writer_options, + )), + }, + ) + } FileTypeWriterOptions::CSV(CsvWriterOptions { - writer_options: _, - compression: _, - }) => return not_impl_err!("CSV file sink protobuf serialization"), + writer_options, + compression, + }) => protobuf::file_type_writer_options::FileType::CsvOptions( + csv_writer_options_to_proto(writer_options, compression), + ), FileTypeWriterOptions::JSON(JsonWriterOptions { compression })
=> { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::file_type_writer_options::FileType::JsonOptions( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 2eb04ab6cbab..27ac5d122f83 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. +use arrow::csv::WriterBuilder; use std::ops::Deref; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; +use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; +use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ @@ -31,6 +34,7 @@ use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; +use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -62,7 +66,9 @@ use datafusion::physical_plan::{ }; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; +use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; +use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; @@ -73,7 +79,23 @@ use datafusion_expr::{ use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> { + let _ = roundtrip_test_and_return(exec_plan); + Ok(()) +} + +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test method returns the final plan after serde so that it can be inspected +/// further in tests.
+fn roundtrip_test_and_return( + exec_plan: Arc<dyn ExecutionPlan>, +) -> Result<Arc<dyn ExecutionPlan>> { let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = @@ -84,9 +106,15 @@ fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> { .try_into_physical_plan(&ctx, runtime.deref(), &codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); - Ok(()) + Ok(result_exec_plan) } +/// Perform a serde roundtrip and assert that the string representation of the before and after plans +/// are identical. Note that this often isn't sufficient to guarantee that no information is +/// lost during serde because the string representation of a plan often only shows a subset of state. +/// +/// This version of the roundtrip_test function accepts a SessionContext, which is required when +/// performing serde on some plans. fn roundtrip_test_with_context( exec_plan: Arc<dyn ExecutionPlan>, ctx: SessionContext, @@ -755,6 +783,101 @@ fn roundtrip_json_sink() -> Result<()> { ))) } +#[test] +fn roundtrip_csv_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::CSV(CsvWriterOptions::new( + WriterBuilder::default(), + CompressionTypeVariant::ZSTD, + )), + }; + let data_sink = Arc::new(CsvSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions { + descending: true, + nulls_first: false, + }), + )]; + + let roundtrip_plan = roundtrip_test_and_return(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) + .unwrap(); + + let roundtrip_plan = roundtrip_plan + .as_any() + .downcast_ref::<FileSinkExec>() + .unwrap(); + let csv_sink = roundtrip_plan + .sink() + .as_any() + .downcast_ref::<CsvSink>() + .unwrap(); + assert_eq!( + CompressionTypeVariant::ZSTD, + csv_sink + .config() + .file_type_writer_options + .try_into_csv() + .unwrap() + .compression + ); + + Ok(()) +} + +#[test] +fn roundtrip_parquet_sink() -> Result<()> { + let field_a = Field::new("plan_type", DataType::Utf8, false); + let field_b = Field::new("plan", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let input = Arc::new(PlaceholderRowExec::new(schema.clone())); + + let file_sink_config = FileSinkConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], + single_file_output: true, + overwrite: true, + file_type_writer_options: FileTypeWriterOptions::Parquet( + ParquetWriterOptions::new(WriterProperties::default()), + ), + }; + let data_sink = Arc::new(ParquetSink::new(file_sink_config)); + let sort_order = vec![PhysicalSortRequirement::new( + Arc::new(Column::new("plan_type", 0)), + Some(SortOptions {
descending: true, + nulls_first: false, + }), + )]; + + roundtrip_test(Arc::new(FileSinkExec::new( + input, + data_sink, + schema.clone(), + Some(sort_order), + ))) +} + #[test] fn roundtrip_sym_hash_join() -> Result<()> { let field_a = Field::new("col", DataType::Int64, false);