diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index c1d7a697ab0..77390fd7d61 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -347,7 +347,9 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Authenticate using the BasicAuth protobuf."), Scenario( "middleware", - description="Ensure headers are propagated via middleware."), + description="Ensure headers are propagated via middleware.", + skip="Rust" # TODO(ARROW-10961): tonic upgrade needed + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) diff --git a/dev/archery/archery/integration/tester_rust.py b/dev/archery/archery/integration/tester_rust.py index 23c2d37386a..bca80ebae3c 100644 --- a/dev/archery/archery/integration/tester_rust.py +++ b/dev/archery/archery/integration/tester_rust.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +import contextlib import os +import subprocess from .tester import Tester from .util import run_cmd, ARROW_ROOT_DEFAULT, log @@ -24,8 +26,8 @@ class RustTester(Tester): PRODUCER = True CONSUMER = True - # FLIGHT_SERVER = True - # FLIGHT_CLIENT = True + FLIGHT_SERVER = True + FLIGHT_CLIENT = True EXE_PATH = os.path.join(ARROW_ROOT_DEFAULT, 'rust/target/debug') @@ -34,11 +36,11 @@ class RustTester(Tester): STREAM_TO_FILE = os.path.join(EXE_PATH, 'arrow-stream-to-file') FILE_TO_STREAM = os.path.join(EXE_PATH, 'arrow-file-to-stream') - # FLIGHT_SERVER_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-server')] - # FLIGHT_CLIENT_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-client'), - # "-host", "localhost"] + FLIGHT_SERVER_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-server')] + FLIGHT_CLIENT_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-client'), + "--host", "localhost"] name = 'Rust' @@ -72,34 +74,42 @@ def file_to_stream(self, file_path, stream_path): cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path] self.run_shell_command(cmd) - # @contextlib.contextmanager - # def flight_server(self): - # cmd = self.FLIGHT_SERVER_CMD + ['-port=0'] - # if self.debug: - # log(' '.join(cmd)) - # server = subprocess.Popen(cmd, - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE) - # try: - # output = server.stdout.readline().decode() - # if not output.startswith("Server listening on localhost:"): - # server.kill() - # out, err = server.communicate() - # raise RuntimeError( - # "Flight-C++ server did not start properly, " - # "stdout:\n{}\n\nstderr:\n{}\n" - # .format(output + out.decode(), err.decode())) - # port = int(output.split(":")[1]) - # yield port - # finally: - # server.kill() - # server.wait(5) - - # def flight_request(self, port, json_path): - # cmd = self.FLIGHT_CLIENT_CMD + [ - # '-port=' + str(port), - # '-path=' + json_path, - # ] - # if self.debug: - # log(' '.join(cmd)) - # run_cmd(cmd) + @contextlib.contextmanager + def flight_server(self, scenario_name=None): + cmd = self.FLIGHT_SERVER_CMD + ['--port=0'] + if scenario_name: + cmd = cmd + ["--scenario", scenario_name] + if self.debug: + log(' '.join(cmd)) + server = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + try: + output = server.stdout.readline().decode() + if not output.startswith("Server listening on localhost:"): + server.kill() + out, err = server.communicate() + raise RuntimeError( + "Flight-Rust server did not start properly, " + 
"stdout:\n{}\n\nstderr:\n{}\n" + .format(output + out.decode(), err.decode())) + port = int(output.split(":")[1]) + yield port + finally: + server.kill() + server.wait(5) + + def flight_request(self, port, json_path=None, scenario_name=None): + cmd = self.FLIGHT_CLIENT_CMD + [ + '--port=' + str(port), + ] + if json_path: + cmd.extend(('--path', json_path)) + elif scenario_name: + cmd.extend(('--scenario', scenario_name)) + else: + raise TypeError("Must provide one of json_path or scenario_name") + + if self.debug: + log(' '.join(cmd)) + run_cmd(cmd) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index c2e01fb6ccc..659668c0baf 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -21,17 +21,18 @@ use std::convert::TryFrom; use crate::{FlightData, SchemaResult}; +use arrow::array::ArrayRef; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result}; -use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions}; +use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteOptions}; use arrow::record_batch::RecordBatch; /// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries -/// and values +/// and a `FlightData` representing the bytes of the batch's values pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, -) -> Vec { +) -> (Vec, FlightData) { let data_gen = writer::IpcDataGenerator::default(); let mut dictionary_tracker = writer::DictionaryTracker::new(false); @@ -39,16 +40,20 @@ pub fn flight_data_from_arrow_batch( .encoded_batch(batch, &mut dictionary_tracker, &options) .expect("DictionaryTracker configured above to not error on replacement"); - encoded_dictionaries - .into_iter() - .chain(std::iter::once(encoded_batch)) - .map(|data| FlightData { - flight_descriptor: None, - app_metadata: vec![], + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + (flight_dictionaries, flight_batch) +} + +impl From for FlightData { + fn from(data: EncodedData) -> Self { + FlightData { data_header: data.ipc_message, data_body: data.arrow_data, - }) - .collect() + ..Default::default() + } + } } /// Convert a `Schema` to `SchemaResult` by converting to an IPC message @@ -56,11 +61,8 @@ pub fn flight_schema_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, ) -> SchemaResult { - let data_gen = writer::IpcDataGenerator::default(); - let schema_bytes = data_gen.schema_to_bytes(schema, &options); - SchemaResult { - schema: schema_bytes.ipc_message, + schema: flight_schema_as_flatbuffer(schema, options), } } @@ -69,16 +71,41 @@ pub fn flight_data_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, ) -> FlightData { - let data_gen = writer::IpcDataGenerator::default(); - let schema = data_gen.schema_to_bytes(schema, &options); + let data_header = flight_schema_as_flatbuffer(schema, options); FlightData { - flight_descriptor: None, - app_metadata: vec![], - data_header: schema.ipc_message, - data_body: vec![], + data_header, + ..Default::default() } } +/// Convert a `Schema` to bytes in the format expected in `FlightInfo.schema` +pub fn ipc_message_from_arrow_schema( + arrow_schema: &Schema, + options: &IpcWriteOptions, +) -> Result> { + let encoded_data = flight_schema_as_encoded_data(arrow_schema, options); + + let mut schema = vec![]; + arrow::ipc::writer::write_message(&mut schema, encoded_data, options)?; + Ok(schema) +} + 
+fn flight_schema_as_flatbuffer( + arrow_schema: &Schema, + options: &IpcWriteOptions, +) -> Vec { + let encoded_data = flight_schema_as_encoded_data(arrow_schema, options); + encoded_data.ipc_message +} + +fn flight_schema_as_encoded_data( + arrow_schema: &Schema, + options: &IpcWriteOptions, +) -> EncodedData { + let data_gen = writer::IpcDataGenerator::default(); + data_gen.schema_to_bytes(arrow_schema, options) +} + /// Try convert `FlightData` into an Arrow Schema /// /// Returns an error if the `FlightData` header is not a valid IPC schema @@ -113,21 +140,12 @@ impl TryFrom<&SchemaResult> for Schema { pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, -) -> Option> { + dictionaries_by_field: &[Option], +) -> Result { // check that the data_header is a record batch message - let res = arrow::ipc::root_as_message(&data.data_header[..]); - - // Catch error. - if let Err(err) = res { - return Some(Err(ArrowError::ParseError(format!( - "Unable to get root as message: {:?}", - err - )))); - } - - let message = res.unwrap(); - - let dictionaries_by_field = Vec::new(); + let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| { + ArrowError::ParseError(format!("Unable to get root as message: {:?}", err)) + })?; message .header_as_record_batch() @@ -136,17 +154,14 @@ pub fn flight_data_to_arrow_batch( "Unable to convert flight data header to a record batch".to_string(), ) }) - .map_or_else( - |err| Some(Err(err)), - |batch| { - Some(reader::read_record_batch( - &data.data_body, - batch, - schema, - &dictionaries_by_field, - )) - }, - ) + .map(|batch| { + reader::read_record_batch( + &data.data_body, + batch, + schema, + &dictionaries_by_field, + ) + })? } // TODO: add more explicit conversion that exposes flight descriptor and metadata options diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs index 809e7177210..7e6d7962a6e 100644 --- a/rust/arrow/src/ipc/reader.rs +++ b/rust/arrow/src/ipc/reader.rs @@ -473,7 +473,7 @@ pub fn read_record_batch( /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries_by_field` with the resulting dictionary -fn read_dictionary( +pub fn read_dictionary( buf: &[u8], batch: ipc::DictionaryBatch, schema: &Schema, diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 5161d548ca6..52fe178168a 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -554,10 +554,9 @@ pub struct EncodedData { /// Arrow buffers to be written, should be an empty vec for schema messages pub arrow_data: Vec, } - /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written -fn write_message( - mut writer: &mut BufWriter, +pub fn write_message( + mut writer: W, encoded: EncodedData, write_options: &IpcWriteOptions, ) -> Result<(usize, usize)> { @@ -602,7 +601,7 @@ fn write_message( Ok((aligned_size, body_len)) } -fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Result { +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { let len = data.len() as u32; let pad_len = pad_to_8(len) as u32; let total_len = len + pad_len; @@ -620,7 +619,7 @@ fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Resul /// Write a record batch to the writer, writing the message size before the message /// if the record batch is being written to a stream fn write_continuation( - writer: &mut BufWriter, + mut writer: W, write_options: &IpcWriteOptions, total_len: i32, ) -> Result { diff --git 
a/rust/benchmarks/README.md b/rust/benchmarks/README.md index 9bff3e2d8ee..2ae035b9fc4 100644 --- a/rust/benchmarks/README.md +++ b/rust/benchmarks/README.md @@ -37,6 +37,7 @@ clone the repository and build the source code. git clone git@github.com:databricks/tpch-dbgen.git cd tpch-dbgen make +export TPCH_DATA=$(pwd) ``` Data can now be generated with the following command. Note that `-s 1` means use Scale Factor 1 or ~1 GB of @@ -63,7 +64,7 @@ This utility does not yet provide support for changing the number of partitions option is to use the following Docker image to perform the conversion from `tbl` files to CSV or Parquet. ```bash -docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT +docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT -h, --help Show help message Subcommand: convert-tpch diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 769668c958b..cdffd3593cc 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -52,7 +52,7 @@ struct BenchmarkOpt { concurrency: usize, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "4096")] + #[structopt(short = "s", long = "batch-size", default_value = "32768")] batch_size: usize, /// Path to data files @@ -108,15 +108,14 @@ const TABLES: &[&str] = &[ #[tokio::main] async fn main() -> Result<()> { + env_logger::init(); match TpchOpt::from_args() { - TpchOpt::Benchmark(opt) => benchmark(opt).await, + TpchOpt::Benchmark(opt) => benchmark(opt).await.map(|_| ()), TpchOpt::Convert(opt) => convert_tbl(opt).await, } } -async fn benchmark(opt: BenchmarkOpt) -> Result<()> { - env_logger::init(); - +async fn benchmark(opt: BenchmarkOpt) -> Result> { println!("Running benchmarks with the following options: {:?}", opt); let config = ExecutionConfig::new() .with_concurrency(opt.concurrency) @@ -146,10 +145,11 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> { let mut millis = vec![]; // run benchmark + let mut result: Vec = Vec::with_capacity(1); for i in 0..opt.iterations { let start = Instant::now(); let plan = create_logical_plan(&mut ctx, opt.query)?; - execute_query(&mut ctx, &plan, opt.debug).await?; + result = execute_query(&mut ctx, &plan, opt.debug).await?; let elapsed = start.elapsed().as_secs_f64() * 1000.0; millis.push(elapsed as f64); println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed); @@ -158,7 +158,7 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {} avg time: {:.2} ms", opt.query, avg); - Ok(()) + Ok(result) } fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result { @@ -994,7 +994,7 @@ async fn execute_query( ctx: &mut ExecutionContext, plan: &LogicalPlan, debug: bool, -) -> Result<()> { +) -> Result> { if debug { println!("Logical plan:\n{:?}", plan); } @@ -1007,12 +1007,11 @@ async fn execute_query( if debug { pretty::print_batches(&result)?; } - Ok(()) + Ok(result) } async fn convert_tbl(opt: ConvertOpt) -> Result<()> { let output_root_path = Path::new(&opt.output_path); - for table in TABLES { let start = Instant::now(); let schema = get_schema(table); @@ -1088,13 +1087,14 @@ fn get_table( table_format: &str, ) -> Result> { match table_format { - // dbgen creates .tbl ('|' delimited) files + // dbgen creates .tbl ('|' delimited) files without header "tbl" => { let path = format!("{}/{}.tbl", path, table); let schema = get_schema(table); let options = 
CsvReadOptions::new() .schema(&schema) .delimiter(b'|') + .has_header(false) .file_extension(".tbl"); Ok(Box::new(CsvFile::try_new(&path, options)?)) @@ -1130,7 +1130,7 @@ fn get_schema(table: &str) -> Schema { Field::new("p_type", DataType::Utf8, false), Field::new("p_size", DataType::Int32, false), Field::new("p_container", DataType::Utf8, false), - Field::new("p_retailprice", DataType::Float64, false), // decimal + Field::new("p_retailprice", DataType::Float64, false), Field::new("p_comment", DataType::Utf8, false), ]), @@ -1140,7 +1140,7 @@ fn get_schema(table: &str) -> Schema { Field::new("s_address", DataType::Utf8, false), Field::new("s_nationkey", DataType::Int32, false), Field::new("s_phone", DataType::Utf8, false), - Field::new("s_acctbal", DataType::Float64, false), // decimal + Field::new("s_acctbal", DataType::Float64, false), Field::new("s_comment", DataType::Utf8, false), ]), @@ -1148,7 +1148,7 @@ fn get_schema(table: &str) -> Schema { Field::new("ps_partkey", DataType::Int32, false), Field::new("ps_suppkey", DataType::Int32, false), Field::new("ps_availqty", DataType::Int32, false), - Field::new("ps_supplycost", DataType::Float64, false), // decimal + Field::new("ps_supplycost", DataType::Float64, false), Field::new("ps_comment", DataType::Utf8, false), ]), @@ -1158,7 +1158,7 @@ fn get_schema(table: &str) -> Schema { Field::new("c_address", DataType::Utf8, false), Field::new("c_nationkey", DataType::Int32, false), Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Float64, false), // decimal + Field::new("c_acctbal", DataType::Float64, false), Field::new("c_mktsegment", DataType::Utf8, false), Field::new("c_comment", DataType::Utf8, false), ]), @@ -1167,7 +1167,7 @@ fn get_schema(table: &str) -> Schema { Field::new("o_orderkey", DataType::Int32, false), Field::new("o_custkey", DataType::Int32, false), Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Float64, false), // decimal + Field::new("o_totalprice", DataType::Float64, false), Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false), Field::new("o_orderpriority", DataType::Utf8, false), Field::new("o_clerk", DataType::Utf8, false), @@ -1180,10 +1180,10 @@ fn get_schema(table: &str) -> Schema { Field::new("l_partkey", DataType::Int32, false), Field::new("l_suppkey", DataType::Int32, false), Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Float64, false), // decimal - Field::new("l_extendedprice", DataType::Float64, false), // decimal - Field::new("l_discount", DataType::Float64, false), // decimal - Field::new("l_tax", DataType::Float64, false), // decimal + Field::new("l_quantity", DataType::Float64, false), + Field::new("l_extendedprice", DataType::Float64, false), + Field::new("l_discount", DataType::Float64, false), + Field::new("l_tax", DataType::Float64, false), Field::new("l_returnflag", DataType::Utf8, false), Field::new("l_linestatus", DataType::Utf8, false), Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false), @@ -1210,3 +1210,404 @@ fn get_schema(table: &str) -> Schema { _ => unimplemented!(), } } + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + use std::sync::Arc; + + use arrow::array::*; + use arrow::record_batch::RecordBatch; + use arrow::util::display::array_value_to_string; + + use datafusion::logical_plan::Expr; + use datafusion::logical_plan::Expr::Cast; + + #[tokio::test] + async fn q1() -> Result<()> { + verify_query(1).await + } + + #[tokio::test] + 
async fn q2() -> Result<()> { + verify_query(2).await + } + + #[tokio::test] + async fn q3() -> Result<()> { + verify_query(3).await + } + + #[tokio::test] + async fn q4() -> Result<()> { + verify_query(4).await + } + + #[tokio::test] + async fn q5() -> Result<()> { + verify_query(5).await + } + + #[tokio::test] + async fn q6() -> Result<()> { + verify_query(6).await + } + + #[tokio::test] + async fn q7() -> Result<()> { + verify_query(7).await + } + + #[tokio::test] + async fn q8() -> Result<()> { + verify_query(8).await + } + + #[tokio::test] + async fn q9() -> Result<()> { + verify_query(9).await + } + + #[tokio::test] + async fn q10() -> Result<()> { + verify_query(10).await + } + + #[tokio::test] + async fn q11() -> Result<()> { + verify_query(11).await + } + + #[tokio::test] + async fn q12() -> Result<()> { + verify_query(12).await + } + + #[tokio::test] + async fn q13() -> Result<()> { + verify_query(13).await + } + + #[tokio::test] + async fn q14() -> Result<()> { + verify_query(14).await + } + + #[tokio::test] + async fn q15() -> Result<()> { + verify_query(15).await + } + + #[tokio::test] + async fn q16() -> Result<()> { + verify_query(16).await + } + + #[tokio::test] + async fn q17() -> Result<()> { + verify_query(17).await + } + + #[tokio::test] + async fn q18() -> Result<()> { + verify_query(18).await + } + + #[tokio::test] + async fn q19() -> Result<()> { + verify_query(19).await + } + + #[tokio::test] + async fn q20() -> Result<()> { + verify_query(20).await + } + + #[tokio::test] + async fn q21() -> Result<()> { + verify_query(21).await + } + + #[tokio::test] + async fn q22() -> Result<()> { + verify_query(22).await + } + + /// Specialised String representation + fn col_str(column: &ArrayRef, row_index: usize) -> String { + if column.is_null(row_index) { + return "NULL".to_string(); + } + + // Special case ListArray as there is no pretty print support for it yet + if let DataType::FixedSizeList(_, n) = column.data_type() { + let array = column + .as_any() + .downcast_ref::() + .unwrap() + .value(row_index); + + let mut r = Vec::with_capacity(*n as usize); + for i in 0..*n { + r.push(col_str(&array, i as usize)); + } + return format!("[{}]", r.join(",")); + } + + array_value_to_string(column, row_index).unwrap() + } + + /// Converts the results into a 2d array of strings, `result[row][column]` + /// Special cases nulls to NULL for testing + fn result_vec(results: &[RecordBatch]) -> Vec> { + let mut result = vec![]; + for batch in results { + for row_index in 0..batch.num_rows() { + let row_vec = batch + .columns() + .iter() + .map(|column| col_str(column, row_index)) + .collect(); + result.push(row_vec); + } + } + result + } + + fn get_answer_schema(n: usize) -> Schema { + match n { + 1 => Schema::new(vec![ + Field::new("l_returnflag", DataType::Utf8, true), + Field::new("l_linestatus", DataType::Utf8, true), + Field::new("sum_qty", DataType::Float64, true), + Field::new("sum_base_price", DataType::Float64, true), + Field::new("sum_disc_price", DataType::Float64, true), + Field::new("sum_charge", DataType::Float64, true), + Field::new("avg_qty", DataType::Float64, true), + Field::new("avg_price", DataType::Float64, true), + Field::new("avg_disc", DataType::Float64, true), + Field::new("count_order", DataType::UInt64, true), + ]), + + 2 => Schema::new(vec![ + Field::new("s_acctbal", DataType::Float64, true), + Field::new("s_name", DataType::Utf8, true), + Field::new("n_name", DataType::Utf8, true), + Field::new("p_partkey", DataType::Int32, true), + Field::new("p_mfgr", 
DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + Field::new("s_phone", DataType::Utf8, true), + Field::new("s_comment", DataType::Utf8, true), + ]), + + 3 => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int32, true), + Field::new("revenue", DataType::Float64, true), + Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true), + Field::new("o_shippriority", DataType::Int32, true), + ]), + + 4 => Schema::new(vec![ + Field::new("o_orderpriority", DataType::Utf8, true), + Field::new("order_count", DataType::Int32, true), + ]), + + 5 => Schema::new(vec![ + Field::new("n_name", DataType::Utf8, true), + Field::new("revenue", DataType::Float64, true), + ]), + + 6 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + + 7 => Schema::new(vec![ + Field::new("supp_nation", DataType::Utf8, true), + Field::new("cust_nation", DataType::Utf8, true), + Field::new("l_year", DataType::Int32, true), + Field::new("revenue", DataType::Float64, true), + ]), + + 8 => Schema::new(vec![ + Field::new("o_year", DataType::Int32, true), + Field::new("mkt_share", DataType::Float64, true), + ]), + + 9 => Schema::new(vec![ + Field::new("nation", DataType::Utf8, true), + Field::new("o_year", DataType::Int32, true), + Field::new("sum_profit", DataType::Float64, true), + ]), + + 10 => Schema::new(vec![ + Field::new("c_custkey", DataType::Int32, true), + Field::new("c_name", DataType::Utf8, true), + Field::new("revenue", DataType::Float64, true), + Field::new("c_acctbal", DataType::Float64, true), + Field::new("n_name", DataType::Utf8, true), + Field::new("c_address", DataType::Utf8, true), + Field::new("c_phone", DataType::Utf8, true), + Field::new("c_comment", DataType::Utf8, true), + ]), + + 11 => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int32, true), + Field::new("value", DataType::Float64, true), + ]), + + 12 => Schema::new(vec![ + Field::new("l_shipmode", DataType::Utf8, true), + Field::new("high_line_count", DataType::Int64, true), + Field::new("low_line_count", DataType::Int64, true), + ]), + + 13 => Schema::new(vec![ + Field::new("c_count", DataType::Int64, true), + Field::new("custdist", DataType::Int64, true), + ]), + + 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + + 15 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + + 16 => Schema::new(vec![ + Field::new("p_brand", DataType::Utf8, true), + Field::new("p_type", DataType::Utf8, true), + Field::new("c_phone", DataType::Int32, true), + Field::new("c_comment", DataType::Int32, true), + ]), + + 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), + + 18 => Schema::new(vec![ + Field::new("c_name", DataType::Utf8, true), + Field::new("c_custkey", DataType::Int32, true), + Field::new("o_orderkey", DataType::Int32, true), + Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true), + Field::new("o_totalprice", DataType::Float64, true), + Field::new("sum_l_quantity", DataType::Float64, true), + ]), + + 19 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + + 20 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + ]), + + 21 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("numwait", DataType::Int32, true), + ]), + + 22 => Schema::new(vec![ + Field::new("cntrycode", DataType::Int32, true), + Field::new("numcust", DataType::Int32, true), + Field::new("totacctbal", DataType::Float64, true), + ]), + + 
_ => unimplemented!(), + } + } + + // convert expected schema to all utf8 so columns can be read as strings to be parsed separately + // this is due to the fact that the csv parser cannot handle leading/trailing spaces + fn string_schema(schema: Schema) -> Schema { + Schema::new( + schema + .fields() + .iter() + .map(|field| { + Field::new( + Field::name(&field), + DataType::Utf8, + Field::is_nullable(&field), + ) + }) + .collect::>(), + ) + } + + // convert the schema to the same but with all columns set to nullable=true. + // this allows direct schema comparison ignoring nullable. + fn nullable_schema(schema: Arc) -> Schema { + Schema::new( + schema + .fields() + .iter() + .map(|field| { + Field::new( + Field::name(&field), + Field::data_type(&field).to_owned(), + true, + ) + }) + .collect::>(), + ) + } + + async fn verify_query(n: usize) -> Result<()> { + if let Ok(path) = env::var("TPCH_DATA") { + // load expected answers from tpch-dbgen + // read csv as all strings, trim and cast to expected type as the csv string + // to value parser does not handle data with leading/trailing spaces + let mut ctx = ExecutionContext::new(); + let schema = string_schema(get_answer_schema(n)); + let options = CsvReadOptions::new() + .schema(&schema) + .delimiter(b'|') + .file_extension(".out"); + let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?; + let df = df.select( + get_answer_schema(n) + .fields() + .iter() + .map(|field| { + Expr::Alias( + Box::new(Cast { + expr: Box::new(trim(col(Field::name(&field)))), + data_type: Field::data_type(&field).to_owned(), + }), + Field::name(&field).to_string(), + ) + }) + .collect::>(), + )?; + let expected = df.collect().await?; + + // run the query to compute actual results of the query + let opt = BenchmarkOpt { + query: n, + debug: false, + iterations: 1, + concurrency: 2, + batch_size: 4096, + path: PathBuf::from(path.to_string()), + file_format: "tbl".to_string(), + mem_table: false, + }; + let actual = benchmark(opt).await?; + + // assert schema equality without comparing nullable values + assert_eq!( + nullable_schema(expected[0].schema()), + nullable_schema(actual[0].schema()) + ); + + // convert both datasets to Vec> for simple comparison + let expected_vec = result_vec(&expected); + let actual_vec = result_vec(&actual); + + // basic result comparison + assert_eq!(expected_vec.len(), actual_vec.len()); + + // compare each row. this works as all TPC-H queries have determinisically ordered results + for i in 0..actual_vec.len() { + assert_eq!(expected_vec[i], actual_vec[i]); + } + } else { + println!("TPCH_DATA environment variable not set, skipping test"); + } + + Ok(()) + } +} diff --git a/rust/datafusion/examples/flight_client.rs b/rust/datafusion/examples/flight_client.rs index 13fd394d187..2c2954d5a02 100644 --- a/rust/datafusion/examples/flight_client.rs +++ b/rust/datafusion/examples/flight_client.rs @@ -62,10 +62,13 @@ async fn main() -> Result<(), Box> { // all the remaining stream messages should be dictionary and record batches let mut results = vec![]; + let dictionaries_by_field = vec![None; schema.fields().len()]; while let Some(flight_data) = stream.message().await? 
{ - // the unwrap is infallible and thus safe - let record_batch = - flight_data_to_arrow_batch(&flight_data, schema.clone()).unwrap()?; + let record_batch = flight_data_to_arrow_batch( + &flight_data, + schema.clone(), + &dictionaries_by_field, + )?; results.push(record_batch); } diff --git a/rust/datafusion/examples/flight_server.rs b/rust/datafusion/examples/flight_server.rs index d73405026bb..a9b3d251464 100644 --- a/rust/datafusion/examples/flight_server.rs +++ b/rust/datafusion/examples/flight_server.rs @@ -125,11 +125,14 @@ impl FlightService for FlightServiceImpl { let mut batches: Vec> = results .iter() .flat_map(|batch| { - let flight_data = + let (flight_dictionaries, flight_batch) = arrow_flight::utils::flight_data_from_arrow_batch( batch, &options, ); - flight_data.into_iter().map(Ok) + flight_dictionaries + .into_iter() + .chain(std::iter::once(flight_batch)) + .map(Ok) }) .collect(); diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index a489c117d5e..481011cf758 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -492,7 +492,7 @@ impl ExecutionConfig { pub fn new() -> Self { Self { concurrency: num_cpus::get(), - batch_size: 4096, + batch_size: 32768, query_planner: Arc::new(DefaultQueryPlanner {}), } } diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 7b47aa218dc..db8b86b0137 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -207,6 +207,7 @@ mod tests { avg(col("c12")), sum(col("c12")), count(col("c12")), + count_distinct(col("c12")), ]; let df = df.aggregate(group_expr, aggr_expr)?; @@ -214,7 +215,7 @@ mod tests { let plan = df.to_logical_plan(); // build same plan using SQL API - let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \ + let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12), COUNT(DISTINCT c12) \ FROM aggregate_test_100 \ GROUP BY c1"; let sql_plan = create_plan(sql)?; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 0ae26a37364..4aee03c4c12 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -570,6 +570,15 @@ pub fn count(expr: Expr) -> Expr { } } +/// Create an expression to represent the count(distinct) aggregate function +pub fn count_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Count, + distinct: true, + args: vec![expr], + } +} + /// Whether it can be represented as a literal expression pub trait Literal { /// convert the value to a Literal expression diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index f810b0162c4..0d37da6f6a3 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -35,9 +35,9 @@ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, - count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln, - log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, tan, trim, trunc, - upper, when, Expr, Literal, + count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, + length, lit, ln, log10, log2, lower, max, min, or, round, signum, sin, 
sqrt, sum, + tan, trim, trunc, upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 309b75bc6b1..c8a4804ae6c 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,7 +28,7 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, length, lit, max, min, sum, JoinType, - Partitioning, + array, avg, col, concat, count, create_udf, length, lit, lower, max, min, sum, trim, + upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; diff --git a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml index 1c2687086fb..4c6d5d02e1a 100644 --- a/rust/integration-testing/Cargo.toml +++ b/rust/integration-testing/Cargo.toml @@ -27,20 +27,14 @@ edition = "2018" [dependencies] arrow = { path = "../arrow" } +arrow-flight = { path = "../arrow-flight" } +async-trait = "0.1.41" clap = "2.33" +futures = "0.3" +hex = "0.4" +prost = "0.6" serde = { version = "1.0", features = ["rc"] } serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } -hex = "0.4" - -[[bin]] -name = "arrow-file-to-stream" -path = "src/bin/arrow-file-to-stream.rs" - -[[bin]] -name = "arrow-stream-to-file" -path = "src/bin/arrow-stream-to-file.rs" - -[[bin]] -name = "arrow-json-integration-test" -path = "src/bin/arrow-json-integration-test.rs" +tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] } +tonic = "0.3" diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs index b1bec677cf1..cd89a8edf1d 100644 --- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,27 +15,15 @@ // specific language governing permissions and limitations // under the License. 
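For the `count_distinct` builder added to `datafusion::logical_plan` above, a minimal usage sketch through the DataFrame API, assuming a registered `aggregate_test_100` table (the `distinct_counts` helper is illustrative only); it builds the same plan as `SELECT c1, COUNT(DISTINCT c12) FROM aggregate_test_100 GROUP BY c1`:

```rust
use arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::logical_plan::count_distinct;
use datafusion::prelude::*;

/// Group by c1 and count the distinct c12 values per group.
async fn distinct_counts(ctx: &mut ExecutionContext) -> Result<Vec<RecordBatch>> {
    let df = ctx.table("aggregate_test_100")?;
    let df = df.aggregate(vec![col("c1")], vec![count_distinct(col("c12"))])?;
    df.collect().await
}
```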
-use std::collections::HashMap; use std::fs::File; -use std::io::BufReader; -use std::sync::Arc; use clap::{App, Arg}; -use hex::decode; -use serde_json::Value; -use arrow::array::*; -use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow::record_batch::RecordBatch; -use arrow::{ - buffer::Buffer, - buffer::MutableBuffer, - datatypes::ToByteSlice, - util::{bit_util, integration_util::*}, -}; +use arrow::util::integration_util::*; +use arrow_integration_testing::read_json_file; fn main() -> Result<()> { let matches = App::new("rust arrow-json-integration-test") @@ -93,520 +81,6 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> Ok(()) } -fn record_batch_from_json( - schema: &Schema, - json_batch: ArrowJsonBatch, - json_dictionaries: Option<&HashMap>, -) -> Result { - let mut columns = vec![]; - - for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { - let col = array_from_json(field, json_col, json_dictionaries)?; - columns.push(col); - } - - RecordBatch::try_new(Arc::new(schema.clone()), columns) -} - -/// Construct an Arrow array from a partially typed JSON column -fn array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dictionaries: Option<&HashMap>, -) -> Result { - match field.data_type() { - DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), - DataType::Boolean => { - let mut b = BooleanBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_bool().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int8 => { - let mut b = Int8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to get {:?} as int64", - value - )) - })? 
as i8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int16 => { - let mut b = Int16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int32 - | DataType::Date32(DateUnit::Day) - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - let mut b = Int32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i32), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::Int64 - | DataType::Date64(DateUnit::Millisecond) - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - let mut b = Int64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(match value { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } - _ => panic!("Unable to parse {:?} as number", value), - }), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::UInt8 => { - let mut b = UInt8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt16 => { - let mut b = UInt16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt32 => { - let mut b = UInt32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt64 => { - let mut b = UInt64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value( - value - .as_str() - .unwrap() - .parse() - .expect("Unable to parse string as u64"), - ), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float32 => { - let mut b = Float32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap() as f32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float64 => { - let mut b = Float64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match 
is_valid { - 1 => b.append_value(value.as_f64().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Binary => { - let mut b = BinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeBinary => { - let mut b = LargeBinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Utf8 => { - let mut b = StringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeUtf8 => { - let mut b = LargeStringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::FixedSizeBinary(len) => { - let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = hex::decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::List(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| v.as_i64().unwrap() as i32) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(ListArray::from(list_data))) - } - DataType::LargeList(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| match v { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => s.parse::().unwrap(), - _ => panic!("64-bit offset must be either string or number"), - }) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(LargeListArray::from(list_data))) - } - DataType::FixedSizeList(child_field, _) => { - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let null_buf 
= create_null_buf(&json_col); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(FixedSizeListArray::from(list_data))) - } - DataType::Struct(fields) => { - // construct struct with null data - let null_buf = create_null_buf(&json_col); - let mut array_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .null_bit_buffer(null_buf); - - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { - let array = array_from_json(field, col, dictionaries)?; - array_data = array_data.add_child_data(array.data()); - } - - let array = StructArray::from(array_data.build()); - Ok(Arc::new(array)) - } - DataType::Dictionary(key_type, value_type) => { - let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {:?}", - field - )) - })?; - // find dictionary - let dictionary = dictionaries - .ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {:?}", - field - )) - })? - .get(&dict_id); - match dictionary { - Some(dictionary) => dictionary_array_from_json( - field, json_col, key_type, value_type, dictionary, - ), - None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {:?}", - field - ))), - } - } - t => Err(ArrowError::JsonError(format!( - "data type {:?} not supported", - t - ))), - } -} - -fn dictionary_array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dict_key: &DataType, - dict_value: &DataType, - dictionary: &ArrowJsonDictionaryBatch, -) -> Result { - match dict_key { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => { - let null_buf = create_null_buf(&json_col); - - // build the key data into a buffer, then construct values separately - let key_field = Field::new_dict( - "key", - dict_key.clone(), - field.is_nullable(), - field - .dict_id() - .expect("Dictionary fields must have a dict_id value"), - field - .dict_is_ordered() - .expect("Dictionary fields must have a dict_is_ordered value"), - ); - let keys = array_from_json(&key_field, json_col, None)?; - // note: not enough info on nullability of dictionary - let value_field = Field::new("value", dict_value.clone(), true); - println!("dictionary value type: {:?}", dict_value); - let values = - array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; - - // convert key and value to dictionary data - let dict_data = ArrayData::builder(field.data_type().clone()) - .len(keys.len()) - .add_buffer(keys.data().buffers()[0].clone()) - .null_bit_buffer(null_buf) - .add_child_data(values.data()) - .build(); - - let array = match dict_key { - DataType::Int8 => { - Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } - DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), - DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), - DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), - DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), - DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), - DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), - DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), - _ => unreachable!(), - }; - Ok(array) - } - _ => Err(ArrowError::JsonError(format!( - "Dictionary key type {:?} not 
supported", - dict_key - ))), - } -} - -/// A helper to create a null buffer from a Vec -fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { - let num_bytes = bit_util::ceil(json_col.count, 8); - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); - json_col - .validity - .clone() - .unwrap() - .iter() - .enumerate() - .for_each(|(i, v)| { - let null_slice = null_buf.data_mut(); - if *v != 0 { - bit_util::set_bit(null_slice, i); - } - }); - null_buf.freeze() -} - fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { eprintln!("Converting {} to {}", arrow_name, json_name); @@ -702,43 +176,3 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { Ok(()) } - -struct ArrowFile { - schema: Schema, - // we can evolve this into a concrete Arrow type - // this is temporarily not being read from - _dictionaries: HashMap, - batches: Vec, -} - -fn read_json_file(json_name: &str) -> Result { - let json_file = File::open(json_name)?; - let reader = BufReader::new(json_file); - let arrow_json: Value = serde_json::from_reader(reader).unwrap(); - let schema = Schema::from(&arrow_json["schema"])?; - // read dictionaries - let mut dictionaries = HashMap::new(); - if let Some(dicts) = arrow_json.get("dictionaries") { - for d in dicts - .as_array() - .expect("Unable to get dictionaries as array") - { - let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) - .expect("Unable to get dictionary from JSON"); - // TODO: convert to a concrete Arrow type - dictionaries.insert(json_dict.id, json_dict); - } - } - - let mut batches = vec![]; - for b in arrow_json["batches"].as_array().unwrap() { - let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); - let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; - batches.push(batch); - } - Ok(ArrowFile { - schema, - _dictionaries: dictionaries, - batches, - }) -} diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs new file mode 100644 index 00000000000..17352360f85 --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_integration_testing::flight_client_scenarios; + +use clap::{App, Arg}; + +type Error = Box; +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result { + let matches = App::new("rust flight-test-integration-client") + .arg(Arg::with_name("host").long("host").takes_value(true)) + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg(Arg::with_name("path").long("path").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let host = matches.value_of("host").expect("Host is required"); + let port = matches.value_of("port").expect("Port is required"); + + match matches.value_of("scenario") { + Some("middleware") => { + flight_client_scenarios::middleware::run_scenario(host, port).await? + } + Some("auth:basic_proto") => { + flight_client_scenarios::auth_basic_proto::run_scenario(host, port).await? + } + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + let path = matches + .value_of("path") + .expect("Path is required if scenario is not specified"); + flight_client_scenarios::integration_test::run_scenario(host, port, path) + .await?; + } + } + + Ok(()) +} diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs new file mode 100644 index 00000000000..5ef1253d1ee --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use clap::{App, Arg}; + +use arrow_integration_testing::flight_server_scenarios; + +type Error = Box; +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result { + let matches = App::new("rust flight-test-integration-server") + .about("Integration testing server for Flight.") + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let port = matches.value_of("port").unwrap_or("0"); + + match matches.value_of("scenario") { + Some("middleware") => { + flight_server_scenarios::middleware::scenario_setup(port).await? + } + Some("auth:basic_proto") => { + flight_server_scenarios::auth_basic_proto::scenario_setup(port).await? 
+ } + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + flight_server_scenarios::integration_test::scenario_setup(port).await?; + } + } + Ok(()) +} diff --git a/rust/integration-testing/src/flight_client_scenarios.rs b/rust/integration-testing/src/flight_client_scenarios.rs new file mode 100644 index 00000000000..66cced5f4c2 --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod auth_basic_proto; +pub mod integration_test; +pub mod middleware; diff --git a/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs new file mode 100644 index 00000000000..5e8cd467198 --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{AUTH_PASSWORD, AUTH_USERNAME}; + +use arrow_flight::{ + flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest, +}; +use futures::{stream, StreamExt}; +use prost::Message; +use tonic::{metadata::MetadataValue, Request, Status}; + +type Error = Box; +type Result = std::result::Result; + +type Client = FlightServiceClient; + +pub async fn run_scenario(host: &str, port: &str) -> Result { + let url = format!("http://{}:{}", host, port); + let mut client = FlightServiceClient::connect(url).await?; + + let action = arrow_flight::Action::default(); + + let resp = client.do_action(Request::new(action.clone())).await; + // This client is unauthenticated and should fail. 
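+ // The server must reject it with the UNAUTHENTICATED status code; any other error, or a successful response, fails the scenario.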
+ match resp { + Err(e) => { + if e.code() != tonic::Code::Unauthenticated { + return Err(Box::new(Status::internal(format!( + "Expected UNAUTHENTICATED but got {:?}", + e + )))); + } + } + Ok(other) => { + return Err(Box::new(Status::internal(format!( + "Expected UNAUTHENTICATED but got {:?}", + other + )))); + } + } + + let token = authenticate(&mut client, AUTH_USERNAME, AUTH_PASSWORD) + .await + .expect("must respond successfully from handshake"); + + let mut request = Request::new(action); + let metadata = request.metadata_mut(); + metadata.insert_bin( + "auth-token-bin", + MetadataValue::from_bytes(token.as_bytes()), + ); + + let resp = client.do_action(request).await?; + let mut resp = resp.into_inner(); + + let r = resp + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + + let body = String::from_utf8(r.body).unwrap(); + assert_eq!(body, AUTH_USERNAME); + + Ok(()) +} + +async fn authenticate( + client: &mut Client, + username: &str, + password: &str, +) -> Result { + let auth = BasicAuth { + username: username.into(), + password: password.into(), + }; + let mut payload = vec![]; + auth.encode(&mut payload)?; + + let req = stream::once(async { + HandshakeRequest { + payload, + ..HandshakeRequest::default() + } + }); + + let rx = client.handshake(Request::new(req)).await?; + let mut rx = rx.into_inner(); + + let r = rx.next().await.expect("must respond from handshake")?; + assert!(rx.next().await.is_none(), "must not respond a second time"); + + Ok(String::from_utf8(r.payload).unwrap()) +} diff --git a/rust/integration-testing/src/flight_client_scenarios/integration_test.rs b/rust/integration-testing/src/flight_client_scenarios/integration_test.rs new file mode 100644 index 00000000000..5705f80f82c --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{read_json_file, ArrowFile}; + +use arrow::{ + array::ArrayRef, + datatypes::SchemaRef, + ipc::{self, reader, writer}, + record_batch::RecordBatch, +}; +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, + utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, Ticket, +}; +use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; +use tonic::{Request, Streaming}; + +use std::sync::Arc; + +type Error = Box; +type Result = std::result::Result; + +type Client = FlightServiceClient; + +pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result { + let url = format!("http://{}:{}", host, port); + + let client = FlightServiceClient::connect(url).await?; + + let ArrowFile { + schema, batches, .. 
+ } = read_json_file(path)?; + + let schema = Arc::new(schema); + + let mut descriptor = FlightDescriptor::default(); + descriptor.set_type(DescriptorType::Path); + descriptor.path = vec![path.to_string()]; + + upload_data( + client.clone(), + schema.clone(), + descriptor.clone(), + batches.clone(), + ) + .await?; + verify_data(client, descriptor, schema, &batches).await?; + + Ok(()) +} + +async fn upload_data( + mut client: Client, + schema: SchemaRef, + descriptor: FlightDescriptor, + original_data: Vec, +) -> Result { + let (mut upload_tx, upload_rx) = mpsc::channel(10); + + let options = arrow::ipc::writer::IpcWriteOptions::default(); + let mut schema_flight_data = + arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options); + schema_flight_data.flight_descriptor = Some(descriptor.clone()); + upload_tx.send(schema_flight_data).await?; + + let mut original_data_iter = original_data.iter().enumerate(); + + if let Some((counter, first_batch)) = original_data_iter.next() { + let metadata = counter.to_string().into_bytes(); + // Preload the first batch into the channel before starting the request + send_batch(&mut upload_tx, &metadata, first_batch, &options).await?; + + let outer = client.do_put(Request::new(upload_rx)).await?; + let mut inner = outer.into_inner(); + + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + + // Stream the rest of the batches + for (counter, batch) in original_data_iter { + let metadata = counter.to_string().into_bytes(); + send_batch(&mut upload_tx, &metadata, batch, &options).await?; + + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + } + } else { + drop(upload_tx); + client.do_put(Request::new(upload_rx)).await?; + } + + Ok(()) +} + +async fn send_batch( + upload_tx: &mut mpsc::Sender, + metadata: &[u8], + batch: &RecordBatch, + options: &writer::IpcWriteOptions, +) -> Result { + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + + upload_tx + .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) + .await?; + + // Only the record batch's FlightData gets app_metadata + batch_flight_data.app_metadata = metadata.to_vec(); + upload_tx.send(batch_flight_data).await?; + Ok(()) +} + +async fn verify_data( + mut client: Client, + descriptor: FlightDescriptor, + expected_schema: SchemaRef, + expected_data: &[RecordBatch], +) -> Result { + let resp = client.get_flight_info(Request::new(descriptor)).await?; + let info = resp.into_inner(); + + assert!( + !info.endpoint.is_empty(), + "No endpoints returned from Flight server", + ); + for endpoint in info.endpoint { + let ticket = endpoint + .ticket + .expect("No ticket returned from Flight server"); + + assert!( + !endpoint.location.is_empty(), + "No locations returned from Flight server", + ); + for location in endpoint.location { + consume_flight_location( + location, + ticket.clone(), + &expected_data, + expected_schema.clone(), + ) + .await?; + } + } + + Ok(()) +} + +async fn consume_flight_location( + location: Location, + ticket: Ticket, + expected_data: &[RecordBatch], + schema: SchemaRef, +) -> Result { + let mut location = location; + // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs + // don't recognize this as valid. 
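+    // Only the scheme is rewritten; the host and port returned in the server's
+    // FlightEndpoint location are preserved, e.g. `grpc+tcp://<host>:<port>`
+    // becomes `grpc://<host>:<port>` so tonic can parse the URI.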
+ location.uri = location.uri.replace("grpc+tcp://", "grpc://"); + + let mut client = FlightServiceClient::connect(location.uri).await?; + let resp = client.do_get(ticket).await?; + let mut resp = resp.into_inner(); + + // We already have the schema from the FlightInfo, but the server sends it again as the + // first FlightData. Ignore this one. + let _schema_again = resp.next().await.unwrap(); + + let mut dictionaries_by_field = vec![None; schema.fields().len()]; + + for (counter, expected_batch) in expected_data.iter().enumerate() { + let data = receive_batch_flight_data( + &mut resp, + schema.clone(), + &mut dictionaries_by_field, + ) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + }); + + let metadata = counter.to_string().into_bytes(); + assert_eq!(metadata, data.app_metadata); + + let actual_batch = + flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_field) + .expect("Unable to convert flight data to Arrow batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + let schema = expected_batch.schema(); + for i in 0..expected_batch.num_columns() { + let field = schema.field(i); + let field_name = field.name(); + + let expected_data = expected_batch.column(i).data(); + let actual_data = actual_batch.column(i).data(); + + assert_eq!(expected_data, actual_data, "Data for field {}", field_name); + } + } + + assert!( + resp.next().await.is_none(), + "Got more batches than the expected: {}", + expected_data.len(), + ); + + Ok(()) +} + +async fn receive_batch_flight_data( + resp: &mut Streaming, + schema: SchemaRef, + dictionaries_by_field: &mut [Option], +) -> Option { + let mut data = resp.next().await?.ok()?; + let mut message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing first message"); + + while message.header_type() == ipc::MessageHeader::DictionaryBatch { + reader::read_dictionary( + &data.data_body, + message + .header_as_dictionary_batch() + .expect("Error parsing dictionary"), + &schema, + dictionaries_by_field, + ) + .expect("Error reading dictionary"); + + data = resp.next().await?.ok()?; + message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing message"); + } + + Some(data) +} diff --git a/rust/integration-testing/src/flight_client_scenarios/middleware.rs b/rust/integration-testing/src/flight_client_scenarios/middleware.rs new file mode 100644 index 00000000000..607eab1018a --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios/middleware.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, + FlightDescriptor, +}; +use tonic::{Request, Status}; + +type Error = Box; +type Result = std::result::Result; + +pub async fn run_scenario(host: &str, port: &str) -> Result { + let url = format!("http://{}:{}", host, port); + let conn = tonic::transport::Endpoint::new(url)?.connect().await?; + let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor); + + let mut descriptor = FlightDescriptor::default(); + descriptor.set_type(DescriptorType::Cmd); + descriptor.cmd = b"".to_vec(); + + // This call is expected to fail. + match client + .get_flight_info(Request::new(descriptor.clone())) + .await + { + Ok(_) => return Err(Box::new(Status::internal("Expected call to fail"))), + Err(e) => { + let headers = e.metadata(); + let middleware_header = headers.get("x-middleware"); + let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or(""); + + if value != "expected value" { + let msg = format!( + "On failing call: Expected to receive header 'x-middleware: expected value', \ + but instead got: '{}'", + value + ); + return Err(Box::new(Status::internal(msg))); + } + } + } + + // This call should succeed + descriptor.cmd = b"success".to_vec(); + let resp = client.get_flight_info(Request::new(descriptor)).await?; + + let headers = resp.metadata(); + let middleware_header = headers.get("x-middleware"); + let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or(""); + + if value != "expected value" { + let msg = format!( + "On success call: Expected to receive header 'x-middleware: expected value', \ + but instead got: '{}'", + value + ); + return Err(Box::new(Status::internal(msg))); + } + + Ok(()) +} + +fn middleware_interceptor(mut req: Request<()>) -> Result, Status> { + let metadata = req.metadata_mut(); + metadata.insert("x-middleware", "expected value".parse().unwrap()); + Ok(req) +} diff --git a/rust/integration-testing/src/flight_server_scenarios.rs b/rust/integration-testing/src/flight_server_scenarios.rs new file mode 100644 index 00000000000..3d99e535f76 --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
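+
+//! Helpers shared by the Flight integration-test server scenarios: binding a
+//! TCP listener on the requested port (`0` lets the OS pick a free port, which
+//! is then reported on stdout) and building a `FlightEndpoint` from a ticket
+//! string and a location URI.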
+ +use std::net::SocketAddr; + +use arrow_flight::{FlightEndpoint, Location, Ticket}; +use tokio::net::TcpListener; + +pub mod auth_basic_proto; +pub mod integration_test; +pub mod middleware; + +type Error = Box; +type Result = std::result::Result; + +pub async fn listen_on(port: &str) -> Result<(TcpListener, SocketAddr)> { + let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; + + let listener = TcpListener::bind(addr).await?; + let addr = listener.local_addr()?; + println!("Server listening on localhost:{}", addr.port()); + + Ok((listener, addr)) +} + +pub fn endpoint(ticket: &str, location_uri: impl Into) -> FlightEndpoint { + FlightEndpoint { + ticket: Some(Ticket { + ticket: ticket.as_bytes().to_vec(), + }), + location: vec![Location { + uri: location_uri.into(), + }], + } +} diff --git a/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs new file mode 100644 index 00000000000..355209f2efb --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
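+
+//! Flight server for the BasicAuth (auth_basic_proto) integration scenario: the
+//! Handshake validates a protobuf-encoded `BasicAuth` against the expected
+//! username and password and returns the username as the bearer token, and
+//! DoAction echoes the authenticated peer identity back to the client.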
+ +use std::pin::Pin; +use std::sync::Arc; + +use arrow_flight::{ + flight_service_server::FlightService, flight_service_server::FlightServiceServer, + Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, + FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, +}; +use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use tokio::sync::Mutex; +use tonic::{ + metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming, +}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +type Error = Box; +type Result = std::result::Result; + +use prost::Message; + +use crate::{AUTH_PASSWORD, AUTH_USERNAME}; + +pub async fn scenario_setup(port: &str) -> Result { + let (mut listener, _) = super::listen_on(port).await?; + + let service = AuthBasicProtoScenarioImpl { + username: AUTH_USERNAME.into(), + password: AUTH_PASSWORD.into(), + peer_identity: Arc::new(Mutex::new(None)), + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +#[derive(Clone)] +pub struct AuthBasicProtoScenarioImpl { + username: Arc, + password: Arc, + peer_identity: Arc>>, +} + +impl AuthBasicProtoScenarioImpl { + async fn check_auth( + &self, + metadata: &MetadataMap, + ) -> Result { + let token = metadata + .get_bin("auth-token-bin") + .and_then(|v| v.to_bytes().ok()) + .and_then(|b| String::from_utf8(b.to_vec()).ok()); + self.is_valid(token).await + } + + async fn is_valid( + &self, + token: Option, + ) -> Result { + match token { + Some(t) if t == *self.username => Ok(GrpcServerCallContext { + peer_identity: self.username.to_string(), + }), + _ => Err(Status::unauthenticated("Invalid token")), + } + } +} + +struct GrpcServerCallContext { + peer_identity: String, +} + +impl GrpcServerCallContext { + pub fn peer_identity(&self) -> &str { + &self.peer_identity + } +} + +#[tonic::async_trait] +impl FlightService for AuthBasicProtoScenarioImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + let (tx, rx) = mpsc::channel(10); + + tokio::spawn({ + let username = self.username.clone(); + let password = self.password.clone(); + + async move { + let requests = request.into_inner(); + + requests + .for_each(move |req| { + let mut tx = tx.clone(); + let req = req.expect("Error reading handshake request"); + let HandshakeRequest { payload, .. 
} = req; + + let auth = BasicAuth::decode(&*payload) + .expect("Error parsing handshake request"); + + let resp = if *auth.username == *username + && *auth.password == *password + { + Ok(HandshakeResponse { + payload: username.as_bytes().to_vec(), + ..HandshakeResponse::default() + }) + } else { + Err(Status::unauthenticated(format!( + "Don't know user {}", + auth.username + ))) + }; + + async move { + tx.send(resp) + .await + .expect("Error sending handshake response"); + } + }) + .await; + } + }); + + Ok(Response::new(Box::pin(rx))) + } + + async fn list_flights( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + let flight_context = self.check_auth(request.metadata()).await?; + // Respond with the authenticated username. + let buf = flight_context.peer_identity().as_bytes().to_vec(); + let result = arrow_flight::Result { body: buf }; + let output = futures::stream::once(async { Ok(result) }); + Ok(Response::new(Box::pin(output) as Self::DoActionStream)) + } + + async fn list_actions( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + request: Request>, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } +} diff --git a/rust/integration-testing/src/flight_server_scenarios/integration_test.rs b/rust/integration-testing/src/flight_server_scenarios/integration_test.rs new file mode 100644 index 00000000000..a555b2efad9 --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
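+
+//! Flight server for the IPC integration-test scenario: DoPut stores the
+//! uploaded record batches keyed by the descriptor path and acknowledges each
+//! batch via its app_metadata, GetFlightInfo reports an endpoint together with
+//! the stored schema and total row count, and DoGet streams the schema,
+//! dictionaries, and batches back to the client.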
+ +use std::collections::HashMap; +use std::convert::TryFrom; +use std::pin::Pin; +use std::sync::Arc; + +use arrow::{ + array::ArrayRef, + datatypes::Schema, + datatypes::SchemaRef, + ipc::{self, reader}, + record_batch::RecordBatch, +}; +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, + FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, PutResult, SchemaResult, Ticket, +}; +use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use tokio::sync::Mutex; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +type Error = Box; +type Result = std::result::Result; + +pub async fn scenario_setup(port: &str) -> Result { + let (mut listener, addr) = super::listen_on(port).await?; + + let service = FlightServiceImpl { + server_location: format!("grpc+tcp://{}", addr), + ..Default::default() + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + + Ok(()) +} + +#[derive(Debug, Clone)] +struct IntegrationDataset { + schema: Schema, + chunks: Vec, +} + +#[derive(Clone, Default)] +pub struct FlightServiceImpl { + server_location: String, + uploaded_chunks: Arc>>, +} + +impl FlightServiceImpl { + fn endpoint_from_path(&self, path: &str) -> FlightEndpoint { + super::endpoint(path, &self.server_location) + } +} + +#[tonic::async_trait] +impl FlightService for FlightServiceImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + let ticket = request.into_inner(); + + let key = String::from_utf8(ticket.ticket.to_vec()) + .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {:?}", e)))?; + + let uploaded_chunks = self.uploaded_chunks.lock().await; + + let flight = uploaded_chunks.get(&key).ok_or_else(|| { + Status::not_found(format!("Could not find flight. 
{}", key)) + })?; + + let options = arrow::ipc::writer::IpcWriteOptions::default(); + + let schema = std::iter::once({ + Ok(arrow_flight::utils::flight_data_from_arrow_schema( + &flight.schema, + &options, + )) + }); + + let batches = flight + .chunks + .iter() + .enumerate() + .flat_map(|(counter, batch)| { + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + + // Only the record batch's FlightData gets app_metadata + let metadata = counter.to_string().into_bytes(); + batch_flight_data.app_metadata = metadata; + + dictionary_flight_data + .into_iter() + .chain(std::iter::once(batch_flight_data)) + .map(Ok) + }); + + let output = futures::stream::iter(schema.chain(batches).collect::>()); + + Ok(Response::new(Box::pin(output) as Self::DoGetStream)) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let descriptor = request.into_inner(); + + match descriptor.r#type { + t if t == DescriptorType::Path as i32 => { + let path = &descriptor.path; + if path.is_empty() { + return Err(Status::invalid_argument("Invalid path")); + } + + let uploaded_chunks = self.uploaded_chunks.lock().await; + let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| { + Status::not_found(format!("Could not find flight. {}", path[0])) + })?; + + let endpoint = self.endpoint_from_path(&path[0]); + + let total_records: usize = + flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + + let options = arrow::ipc::writer::IpcWriteOptions::default(); + let schema = arrow_flight::utils::ipc_message_from_arrow_schema( + &flight.schema, + &options, + ) + .expect( + "Could not generate schema bytes from schema stored by a DoPut; \ + this should be impossible", + ); + + let info = FlightInfo { + schema, + flight_descriptor: Some(descriptor.clone()), + endpoint: vec![endpoint], + total_records: total_records as i64, + total_bytes: -1, + }; + + Ok(Response::new(info)) + } + other => Err(Status::unimplemented(format!("Request type: {}", other))), + } + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + let mut input_stream = request.into_inner(); + let flight_data = input_stream + .message() + .await? 
+ .ok_or_else(|| Status::invalid_argument("Must send some FlightData"))?; + + let descriptor = flight_data + .flight_descriptor + .clone() + .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?; + + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() + { + return Err(Status::invalid_argument("Must specify a path")); + } + + let key = descriptor.path[0].clone(); + + let schema = Schema::try_from(&flight_data) + .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; + let schema_ref = Arc::new(schema.clone()); + + let (response_tx, response_rx) = mpsc::channel(10); + + let uploaded_chunks = self.uploaded_chunks.clone(); + + tokio::spawn(async { + let mut error_tx = response_tx.clone(); + if let Err(e) = save_uploaded_chunks( + uploaded_chunks, + schema_ref, + input_stream, + response_tx, + schema, + key, + ) + .await + { + error_tx.send(Err(e)).await.expect("Error sending error") + } + }); + + Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream)) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +async fn send_app_metadata( + tx: &mut mpsc::Sender>, + app_metadata: &[u8], +) -> Result<(), Status> { + tx.send(Ok(PutResult { + app_metadata: app_metadata.to_vec(), + })) + .await + .map_err(|e| Status::internal(format!("Could not send PutResult: {:?}", e))) +} + +async fn record_batch_from_message( + message: ipc::Message<'_>, + data_body: &[u8], + schema_ref: SchemaRef, + dictionaries_by_field: &[Option], +) -> Result { + let ipc_batch = message.header_as_record_batch().ok_or_else(|| { + Status::internal("Could not parse message header as record batch") + })?; + + let arrow_batch_result = reader::read_record_batch( + data_body, + ipc_batch, + schema_ref, + &dictionaries_by_field, + ); + + arrow_batch_result.map_err(|e| { + Status::internal(format!("Could not convert to RecordBatch: {:?}", e)) + }) +} + +async fn dictionary_from_message( + message: ipc::Message<'_>, + data_body: &[u8], + schema_ref: SchemaRef, + dictionaries_by_field: &mut [Option], +) -> Result<(), Status> { + let ipc_batch = message.header_as_dictionary_batch().ok_or_else(|| { + Status::internal("Could not parse message header as dictionary batch") + })?; + + let dictionary_batch_result = + reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_field); + dictionary_batch_result.map_err(|e| { + Status::internal(format!("Could not convert to Dictionary: {:?}", e)) + }) +} + +async fn save_uploaded_chunks( + uploaded_chunks: Arc>>, + schema_ref: Arc, + mut input_stream: Streaming, + mut response_tx: mpsc::Sender>, + schema: Schema, + key: String, +) -> Result<(), Status> { + let mut chunks = vec![]; + let mut uploaded_chunks = uploaded_chunks.lock().await; + + let mut dictionaries_by_field = vec![None; schema_ref.fields().len()]; + + while let Some(Ok(data)) = input_stream.next().await { + let message = arrow::ipc::root_as_message(&data.data_header[..]) + .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?; + + match message.header_type() { + ipc::MessageHeader::Schema => { + return Err(Status::internal( + "Not expecting a schema when messages are 
read", + )) + } + ipc::MessageHeader::RecordBatch => { + send_app_metadata(&mut response_tx, &data.app_metadata).await?; + + let batch = record_batch_from_message( + message, + &data.data_body, + schema_ref.clone(), + &dictionaries_by_field, + ) + .await?; + + chunks.push(batch); + } + ipc::MessageHeader::DictionaryBatch => { + dictionary_from_message( + message, + &data.data_body, + schema_ref.clone(), + &mut dictionaries_by_field, + ) + .await?; + } + t => { + return Err(Status::internal(format!( + "Reading types other than record batches not yet supported, \ + unable to read {:?}", + t + ))); + } + } + } + + let dataset = IntegrationDataset { schema, chunks }; + uploaded_chunks.insert(key, dataset); + + Ok(()) +} diff --git a/rust/integration-testing/src/flight_server_scenarios/middleware.rs b/rust/integration-testing/src/flight_server_scenarios/middleware.rs new file mode 100644 index 00000000000..12421bc8928 --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios/middleware.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::pin::Pin; + +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, + FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, + PutResult, SchemaResult, Ticket, +}; +use futures::Stream; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +type Error = Box; +type Result = std::result::Result; + +pub async fn scenario_setup(port: &str) -> Result { + let (mut listener, _) = super::listen_on(port).await?; + + let service = MiddlewareScenarioImpl {}; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +#[derive(Clone, Default)] +pub struct MiddlewareScenarioImpl {} + +#[tonic::async_trait] +impl FlightService for MiddlewareScenarioImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let middleware_header = request.metadata().get("x-middleware").cloned(); + + let descriptor = request.into_inner(); + + if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd == b"success" + { + // Return a fake location - the test doesn't read it + let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010"); + + let info = FlightInfo { + flight_descriptor: Some(descriptor), + endpoint: vec![endpoint], + ..Default::default() + }; + + let mut response = Response::new(info); + if let Some(value) = middleware_header { + response.metadata_mut().insert("x-middleware", value); + } + + return Ok(response); + } + + let mut status = Status::unknown("Unknown"); + if let Some(value) = middleware_header { + status.metadata_mut().insert("x-middleware", value); + } + + Err(status) + } + + async fn do_put( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs index 596017a79bd..93bc8d55626 100644 --- a/rust/integration-testing/src/lib.rs +++ b/rust/integration-testing/src/lib.rs @@ -16,3 +16,586 @@ // under the License. //! 
Common code used in the integration test binaries + +use hex::decode; +use serde_json::Value; + +use arrow::util::integration_util::ArrowJsonBatch; + +use arrow::array::*; +use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; +use arrow::error::{ArrowError, Result}; +use arrow::record_batch::RecordBatch; +use arrow::{ + buffer::Buffer, + buffer::MutableBuffer, + datatypes::ToByteSlice, + util::{bit_util, integration_util::*}, +}; + +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; + +/// The expected username for the basic auth integration test. +pub const AUTH_USERNAME: &str = "arrow"; +/// The expected password for the basic auth integration test. +pub const AUTH_PASSWORD: &str = "flight"; + +pub mod flight_client_scenarios; +pub mod flight_server_scenarios; + +pub struct ArrowFile { + pub schema: Schema, + // we can evolve this into a concrete Arrow type + // this is temporarily not being read from + pub _dictionaries: HashMap, + pub batches: Vec, +} + +pub fn read_json_file(json_name: &str) -> Result { + let json_file = File::open(json_name)?; + let reader = BufReader::new(json_file); + let arrow_json: Value = serde_json::from_reader(reader).unwrap(); + let schema = Schema::from(&arrow_json["schema"])?; + // read dictionaries + let mut dictionaries = HashMap::new(); + if let Some(dicts) = arrow_json.get("dictionaries") { + for d in dicts + .as_array() + .expect("Unable to get dictionaries as array") + { + let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) + .expect("Unable to get dictionary from JSON"); + // TODO: convert to a concrete Arrow type + dictionaries.insert(json_dict.id, json_dict); + } + } + + let mut batches = vec![]; + for b in arrow_json["batches"].as_array().unwrap() { + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; + batches.push(batch); + } + Ok(ArrowFile { + schema, + _dictionaries: dictionaries, + batches, + }) +} + +fn record_batch_from_json( + schema: &Schema, + json_batch: ArrowJsonBatch, + json_dictionaries: Option<&HashMap>, +) -> Result { + let mut columns = vec![]; + + for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { + let col = array_from_json(field, json_col, json_dictionaries)?; + columns.push(col); + } + + RecordBatch::try_new(Arc::new(schema.clone()), columns) +} + +/// Construct an Arrow array from a partially typed JSON column +fn array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dictionaries: Option<&HashMap>, +) -> Result { + match field.data_type() { + DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), + DataType::Boolean => { + let mut b = BooleanBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_bool().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int8 => { + let mut b = Int8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to get {:?} as int64", + value + )) + })? 
as i8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int16 => { + let mut b = Int16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int32 + | DataType::Date32(DateUnit::Day) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + let mut b = Int32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i32), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::Int64 + | DataType::Date64(DateUnit::Millisecond) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + let mut b = Int64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => { + s.parse().expect("Unable to parse string as i64") + } + _ => panic!("Unable to parse {:?} as number", value), + }), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::UInt8 => { + let mut b = UInt8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt16 => { + let mut b = UInt16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt32 => { + let mut b = UInt32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt64 => { + let mut b = UInt64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value( + value + .as_str() + .unwrap() + .parse() + .expect("Unable to parse string as u64"), + ), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float32 => { + let mut b = Float32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_f64().unwrap() as f32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float64 => { + let mut b = Float64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match 
is_valid { + 1 => b.append_value(value.as_f64().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Binary => { + let mut b = BinaryBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeBinary => { + let mut b = LargeBinaryBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Utf8 => { + let mut b = StringBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeUtf8 => { + let mut b = LargeStringBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::FixedSizeBinary(len) => { + let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = hex::decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::List(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| v.as_i64().unwrap() as i32) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(ListArray::from(list_data))) + } + DataType::LargeList(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| match v { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => s.parse::().unwrap(), + _ => panic!("64-bit offset must be either string or number"), + }) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(LargeListArray::from(list_data))) + } + DataType::FixedSizeList(child_field, _) => { + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let null_buf 
= create_null_buf(&json_col); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(FixedSizeListArray::from(list_data))) + } + DataType::Struct(fields) => { + // construct struct with null data + let null_buf = create_null_buf(&json_col); + let mut array_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .null_bit_buffer(null_buf); + + for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let array = array_from_json(field, col, dictionaries)?; + array_data = array_data.add_child_data(array.data()); + } + + let array = StructArray::from(array_data.build()); + Ok(Arc::new(array)) + } + DataType::Dictionary(key_type, value_type) => { + let dict_id = field.dict_id().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find dict_id for field {:?}", + field + )) + })?; + // find dictionary + let dictionary = dictionaries + .ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find any dictionaries for field {:?}", + field + )) + })? + .get(&dict_id); + match dictionary { + Some(dictionary) => dictionary_array_from_json( + field, json_col, key_type, value_type, dictionary, + ), + None => Err(ArrowError::JsonError(format!( + "Unable to find dictionary for field {:?}", + field + ))), + } + } + t => Err(ArrowError::JsonError(format!( + "data type {:?} not supported", + t + ))), + } +} + +fn dictionary_array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dict_key: &DataType, + dict_value: &DataType, + dictionary: &ArrowJsonDictionaryBatch, +) -> Result { + match dict_key { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + let null_buf = create_null_buf(&json_col); + + // build the key data into a buffer, then construct values separately + let key_field = Field::new_dict( + "key", + dict_key.clone(), + field.is_nullable(), + field + .dict_id() + .expect("Dictionary fields must have a dict_id value"), + field + .dict_is_ordered() + .expect("Dictionary fields must have a dict_is_ordered value"), + ); + let keys = array_from_json(&key_field, json_col, None)?; + // note: not enough info on nullability of dictionary + let value_field = Field::new("value", dict_value.clone(), true); + println!("dictionary value type: {:?}", dict_value); + let values = + array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; + + // convert key and value to dictionary data + let dict_data = ArrayData::builder(field.data_type().clone()) + .len(keys.len()) + .add_buffer(keys.data().buffers()[0].clone()) + .null_bit_buffer(null_buf) + .add_child_data(values.data()) + .build(); + + let array = match dict_key { + DataType::Int8 => { + Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef + } + DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), + DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), + DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), + DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), + DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), + DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), + DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), + _ => unreachable!(), + }; + Ok(array) + } + _ => Err(ArrowError::JsonError(format!( + "Dictionary key type {:?} not 
supported", + dict_key + ))), + } +} + +/// A helper to create a null buffer from a Vec +fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { + let num_bytes = bit_util::ceil(json_col.count, 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + json_col + .validity + .clone() + .unwrap() + .iter() + .enumerate() + .for_each(|(i, v)| { + let null_slice = null_buf.data_mut(); + if *v != 0 { + bit_util::set_bit(null_slice, i); + } + }); + null_buf.freeze() +}