From b07489bf7086bb3d857f3c7fd9dc1d634d1b3c5b Mon Sep 17 00:00:00 2001 From: Mike Seddon Date: Mon, 28 Dec 2020 10:05:28 -0700 Subject: [PATCH 01/12] ARROW-10712: [Rust] [DataFusion] Add tests to TPC-H benchmarks This PR adds the ability to load and parse the expected query answers included in the https://github.com/databricks/tpch-dbgen/tree/master/answers repository - which users have to clone anyway to generate the TPC-H data. Currently DataFusion does not support Decimal types which all the numeric values in TPC-H are so there are the expected precision errors in the current results. These tests are still useful as they show some interesting results already such as non-deterministic query 5 results. @andygrove Closes #9015 from seddonm1/test-tpch Authored-by: Mike Seddon Signed-off-by: Andy Grove --- rust/benchmarks/README.md | 3 +- rust/benchmarks/src/bin/tpch.rs | 439 ++++++++++++++++++++++++++++++-- rust/datafusion/src/prelude.rs | 4 +- 3 files changed, 424 insertions(+), 22 deletions(-) diff --git a/rust/benchmarks/README.md b/rust/benchmarks/README.md index 9bff3e2d8ee..2ae035b9fc4 100644 --- a/rust/benchmarks/README.md +++ b/rust/benchmarks/README.md @@ -37,6 +37,7 @@ clone the repository and build the source code. git clone git@github.com:databricks/tpch-dbgen.git cd tpch-dbgen make +export TPCH_DATA=$(pwd) ``` Data can now be generated with the following command. Note that `-s 1` means use Scale Factor 1 or ~1 GB of @@ -63,7 +64,7 @@ This utility does not yet provide support for changing the number of partitions option is to use the following Docker image to perform the conversion from `tbl` files to CSV or Parquet. ```bash -docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT +docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT -h, --help Show help message Subcommand: convert-tpch diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index 769668c958b..eb789baa7a2 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -108,15 +108,14 @@ const TABLES: &[&str] = &[ #[tokio::main] async fn main() -> Result<()> { + env_logger::init(); match TpchOpt::from_args() { - TpchOpt::Benchmark(opt) => benchmark(opt).await, + TpchOpt::Benchmark(opt) => benchmark(opt).await.map(|_| ()), TpchOpt::Convert(opt) => convert_tbl(opt).await, } } -async fn benchmark(opt: BenchmarkOpt) -> Result<()> { - env_logger::init(); - +async fn benchmark(opt: BenchmarkOpt) -> Result> { println!("Running benchmarks with the following options: {:?}", opt); let config = ExecutionConfig::new() .with_concurrency(opt.concurrency) @@ -146,10 +145,11 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> { let mut millis = vec![]; // run benchmark + let mut result: Vec = Vec::with_capacity(1); for i in 0..opt.iterations { let start = Instant::now(); let plan = create_logical_plan(&mut ctx, opt.query)?; - execute_query(&mut ctx, &plan, opt.debug).await?; + result = execute_query(&mut ctx, &plan, opt.debug).await?; let elapsed = start.elapsed().as_secs_f64() * 1000.0; millis.push(elapsed as f64); println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed); @@ -158,7 +158,7 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> { let avg = millis.iter().sum::() / millis.len() as f64; println!("Query {} avg time: {:.2} ms", opt.query, avg); - Ok(()) + Ok(result) } fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result { @@ -994,7 +994,7 @@ async fn execute_query( ctx: &mut ExecutionContext, plan: &LogicalPlan, debug: 
bool, -) -> Result<()> { +) -> Result> { if debug { println!("Logical plan:\n{:?}", plan); } @@ -1007,12 +1007,11 @@ async fn execute_query( if debug { pretty::print_batches(&result)?; } - Ok(()) + Ok(result) } async fn convert_tbl(opt: ConvertOpt) -> Result<()> { let output_root_path = Path::new(&opt.output_path); - for table in TABLES { let start = Instant::now(); let schema = get_schema(table); @@ -1088,13 +1087,14 @@ fn get_table( table_format: &str, ) -> Result> { match table_format { - // dbgen creates .tbl ('|' delimited) files + // dbgen creates .tbl ('|' delimited) files without header "tbl" => { let path = format!("{}/{}.tbl", path, table); let schema = get_schema(table); let options = CsvReadOptions::new() .schema(&schema) .delimiter(b'|') + .has_header(false) .file_extension(".tbl"); Ok(Box::new(CsvFile::try_new(&path, options)?)) @@ -1130,7 +1130,7 @@ fn get_schema(table: &str) -> Schema { Field::new("p_type", DataType::Utf8, false), Field::new("p_size", DataType::Int32, false), Field::new("p_container", DataType::Utf8, false), - Field::new("p_retailprice", DataType::Float64, false), // decimal + Field::new("p_retailprice", DataType::Float64, false), Field::new("p_comment", DataType::Utf8, false), ]), @@ -1140,7 +1140,7 @@ fn get_schema(table: &str) -> Schema { Field::new("s_address", DataType::Utf8, false), Field::new("s_nationkey", DataType::Int32, false), Field::new("s_phone", DataType::Utf8, false), - Field::new("s_acctbal", DataType::Float64, false), // decimal + Field::new("s_acctbal", DataType::Float64, false), Field::new("s_comment", DataType::Utf8, false), ]), @@ -1148,7 +1148,7 @@ fn get_schema(table: &str) -> Schema { Field::new("ps_partkey", DataType::Int32, false), Field::new("ps_suppkey", DataType::Int32, false), Field::new("ps_availqty", DataType::Int32, false), - Field::new("ps_supplycost", DataType::Float64, false), // decimal + Field::new("ps_supplycost", DataType::Float64, false), Field::new("ps_comment", DataType::Utf8, false), ]), @@ -1158,7 +1158,7 @@ fn get_schema(table: &str) -> Schema { Field::new("c_address", DataType::Utf8, false), Field::new("c_nationkey", DataType::Int32, false), Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Float64, false), // decimal + Field::new("c_acctbal", DataType::Float64, false), Field::new("c_mktsegment", DataType::Utf8, false), Field::new("c_comment", DataType::Utf8, false), ]), @@ -1167,7 +1167,7 @@ fn get_schema(table: &str) -> Schema { Field::new("o_orderkey", DataType::Int32, false), Field::new("o_custkey", DataType::Int32, false), Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Float64, false), // decimal + Field::new("o_totalprice", DataType::Float64, false), Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false), Field::new("o_orderpriority", DataType::Utf8, false), Field::new("o_clerk", DataType::Utf8, false), @@ -1180,10 +1180,10 @@ fn get_schema(table: &str) -> Schema { Field::new("l_partkey", DataType::Int32, false), Field::new("l_suppkey", DataType::Int32, false), Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Float64, false), // decimal - Field::new("l_extendedprice", DataType::Float64, false), // decimal - Field::new("l_discount", DataType::Float64, false), // decimal - Field::new("l_tax", DataType::Float64, false), // decimal + Field::new("l_quantity", DataType::Float64, false), + Field::new("l_extendedprice", DataType::Float64, false), + Field::new("l_discount", 
DataType::Float64, false), + Field::new("l_tax", DataType::Float64, false), Field::new("l_returnflag", DataType::Utf8, false), Field::new("l_linestatus", DataType::Utf8, false), Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false), @@ -1210,3 +1210,404 @@ fn get_schema(table: &str) -> Schema { _ => unimplemented!(), } } + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + use std::sync::Arc; + + use arrow::array::*; + use arrow::record_batch::RecordBatch; + use arrow::util::display::array_value_to_string; + + use datafusion::logical_plan::Expr; + use datafusion::logical_plan::Expr::Cast; + + #[tokio::test] + async fn q1() -> Result<()> { + verify_query(1).await + } + + #[tokio::test] + async fn q2() -> Result<()> { + verify_query(2).await + } + + #[tokio::test] + async fn q3() -> Result<()> { + verify_query(3).await + } + + #[tokio::test] + async fn q4() -> Result<()> { + verify_query(4).await + } + + #[tokio::test] + async fn q5() -> Result<()> { + verify_query(5).await + } + + #[tokio::test] + async fn q6() -> Result<()> { + verify_query(6).await + } + + #[tokio::test] + async fn q7() -> Result<()> { + verify_query(7).await + } + + #[tokio::test] + async fn q8() -> Result<()> { + verify_query(8).await + } + + #[tokio::test] + async fn q9() -> Result<()> { + verify_query(9).await + } + + #[tokio::test] + async fn q10() -> Result<()> { + verify_query(10).await + } + + #[tokio::test] + async fn q11() -> Result<()> { + verify_query(11).await + } + + #[tokio::test] + async fn q12() -> Result<()> { + verify_query(12).await + } + + #[tokio::test] + async fn q13() -> Result<()> { + verify_query(13).await + } + + #[tokio::test] + async fn q14() -> Result<()> { + verify_query(14).await + } + + #[tokio::test] + async fn q15() -> Result<()> { + verify_query(15).await + } + + #[tokio::test] + async fn q16() -> Result<()> { + verify_query(16).await + } + + #[tokio::test] + async fn q17() -> Result<()> { + verify_query(17).await + } + + #[tokio::test] + async fn q18() -> Result<()> { + verify_query(18).await + } + + #[tokio::test] + async fn q19() -> Result<()> { + verify_query(19).await + } + + #[tokio::test] + async fn q20() -> Result<()> { + verify_query(20).await + } + + #[tokio::test] + async fn q21() -> Result<()> { + verify_query(21).await + } + + #[tokio::test] + async fn q22() -> Result<()> { + verify_query(22).await + } + + /// Specialised String representation + fn col_str(column: &ArrayRef, row_index: usize) -> String { + if column.is_null(row_index) { + return "NULL".to_string(); + } + + // Special case ListArray as there is no pretty print support for it yet + if let DataType::FixedSizeList(_, n) = column.data_type() { + let array = column + .as_any() + .downcast_ref::() + .unwrap() + .value(row_index); + + let mut r = Vec::with_capacity(*n as usize); + for i in 0..*n { + r.push(col_str(&array, i as usize)); + } + return format!("[{}]", r.join(",")); + } + + array_value_to_string(column, row_index).unwrap() + } + + /// Converts the results into a 2d array of strings, `result[row][column]` + /// Special cases nulls to NULL for testing + fn result_vec(results: &[RecordBatch]) -> Vec> { + let mut result = vec![]; + for batch in results { + for row_index in 0..batch.num_rows() { + let row_vec = batch + .columns() + .iter() + .map(|column| col_str(column, row_index)) + .collect(); + result.push(row_vec); + } + } + result + } + + fn get_answer_schema(n: usize) -> Schema { + match n { + 1 => Schema::new(vec![ + Field::new("l_returnflag", DataType::Utf8, true), + 
Field::new("l_linestatus", DataType::Utf8, true), + Field::new("sum_qty", DataType::Float64, true), + Field::new("sum_base_price", DataType::Float64, true), + Field::new("sum_disc_price", DataType::Float64, true), + Field::new("sum_charge", DataType::Float64, true), + Field::new("avg_qty", DataType::Float64, true), + Field::new("avg_price", DataType::Float64, true), + Field::new("avg_disc", DataType::Float64, true), + Field::new("count_order", DataType::UInt64, true), + ]), + + 2 => Schema::new(vec![ + Field::new("s_acctbal", DataType::Float64, true), + Field::new("s_name", DataType::Utf8, true), + Field::new("n_name", DataType::Utf8, true), + Field::new("p_partkey", DataType::Int32, true), + Field::new("p_mfgr", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + Field::new("s_phone", DataType::Utf8, true), + Field::new("s_comment", DataType::Utf8, true), + ]), + + 3 => Schema::new(vec![ + Field::new("l_orderkey", DataType::Int32, true), + Field::new("revenue", DataType::Float64, true), + Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true), + Field::new("o_shippriority", DataType::Int32, true), + ]), + + 4 => Schema::new(vec![ + Field::new("o_orderpriority", DataType::Utf8, true), + Field::new("order_count", DataType::Int32, true), + ]), + + 5 => Schema::new(vec![ + Field::new("n_name", DataType::Utf8, true), + Field::new("revenue", DataType::Float64, true), + ]), + + 6 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + + 7 => Schema::new(vec![ + Field::new("supp_nation", DataType::Utf8, true), + Field::new("cust_nation", DataType::Utf8, true), + Field::new("l_year", DataType::Int32, true), + Field::new("revenue", DataType::Float64, true), + ]), + + 8 => Schema::new(vec![ + Field::new("o_year", DataType::Int32, true), + Field::new("mkt_share", DataType::Float64, true), + ]), + + 9 => Schema::new(vec![ + Field::new("nation", DataType::Utf8, true), + Field::new("o_year", DataType::Int32, true), + Field::new("sum_profit", DataType::Float64, true), + ]), + + 10 => Schema::new(vec![ + Field::new("c_custkey", DataType::Int32, true), + Field::new("c_name", DataType::Utf8, true), + Field::new("revenue", DataType::Float64, true), + Field::new("c_acctbal", DataType::Float64, true), + Field::new("n_name", DataType::Utf8, true), + Field::new("c_address", DataType::Utf8, true), + Field::new("c_phone", DataType::Utf8, true), + Field::new("c_comment", DataType::Utf8, true), + ]), + + 11 => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int32, true), + Field::new("value", DataType::Float64, true), + ]), + + 12 => Schema::new(vec![ + Field::new("l_shipmode", DataType::Utf8, true), + Field::new("high_line_count", DataType::Int64, true), + Field::new("low_line_count", DataType::Int64, true), + ]), + + 13 => Schema::new(vec![ + Field::new("c_count", DataType::Int64, true), + Field::new("custdist", DataType::Int64, true), + ]), + + 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + + 15 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + + 16 => Schema::new(vec![ + Field::new("p_brand", DataType::Utf8, true), + Field::new("p_type", DataType::Utf8, true), + Field::new("c_phone", DataType::Int32, true), + Field::new("c_comment", DataType::Int32, true), + ]), + + 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), + + 18 => Schema::new(vec![ + Field::new("c_name", DataType::Utf8, true), + Field::new("c_custkey", DataType::Int32, true), + Field::new("o_orderkey", 
DataType::Int32, true), + Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true), + Field::new("o_totalprice", DataType::Float64, true), + Field::new("sum_l_quantity", DataType::Float64, true), + ]), + + 19 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + + 20 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + ]), + + 21 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("numwait", DataType::Int32, true), + ]), + + 22 => Schema::new(vec![ + Field::new("cntrycode", DataType::Int32, true), + Field::new("numcust", DataType::Int32, true), + Field::new("totacctbal", DataType::Float64, true), + ]), + + _ => unimplemented!(), + } + } + + // convert expected schema to all utf8 so columns can be read as strings to be parsed separately + // this is due to the fact that the csv parser cannot handle leading/trailing spaces + fn string_schema(schema: Schema) -> Schema { + Schema::new( + schema + .fields() + .iter() + .map(|field| { + Field::new( + Field::name(&field), + DataType::Utf8, + Field::is_nullable(&field), + ) + }) + .collect::>(), + ) + } + + // convert the schema to the same but with all columns set to nullable=true. + // this allows direct schema comparison ignoring nullable. + fn nullable_schema(schema: Arc) -> Schema { + Schema::new( + schema + .fields() + .iter() + .map(|field| { + Field::new( + Field::name(&field), + Field::data_type(&field).to_owned(), + true, + ) + }) + .collect::>(), + ) + } + + async fn verify_query(n: usize) -> Result<()> { + if let Ok(path) = env::var("TPCH_DATA") { + // load expected answers from tpch-dbgen + // read csv as all strings, trim and cast to expected type as the csv string + // to value parser does not handle data with leading/trailing spaces + let mut ctx = ExecutionContext::new(); + let schema = string_schema(get_answer_schema(n)); + let options = CsvReadOptions::new() + .schema(&schema) + .delimiter(b'|') + .file_extension(".out"); + let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?; + let df = df.select( + get_answer_schema(n) + .fields() + .iter() + .map(|field| { + Expr::Alias( + Box::new(Cast { + expr: Box::new(trim(col(Field::name(&field)))), + data_type: Field::data_type(&field).to_owned(), + }), + Field::name(&field).to_string(), + ) + }) + .collect::>(), + )?; + let expected = df.collect().await?; + + // run the query to compute actual results of the query + let opt = BenchmarkOpt { + query: n, + debug: false, + iterations: 1, + concurrency: 2, + batch_size: 4096, + path: PathBuf::from(path.to_string()), + file_format: "tbl".to_string(), + mem_table: false, + }; + let actual = benchmark(opt).await?; + + // assert schema equality without comparing nullable values + assert_eq!( + nullable_schema(expected[0].schema()), + nullable_schema(actual[0].schema()) + ); + + // convert both datasets to Vec> for simple comparison + let expected_vec = result_vec(&expected); + let actual_vec = result_vec(&actual); + + // basic result comparison + assert_eq!(expected_vec.len(), actual_vec.len()); + + // compare each row. 
this works as all TPC-H queries have determinisically ordered results + for i in 0..actual_vec.len() { + assert_eq!(expected_vec[i], actual_vec[i]); + } + } else { + println!("TPCH_DATA environment variable not set, skipping test"); + } + + Ok(()) + } +} diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs index 309b75bc6b1..c8a4804ae6c 100644 --- a/rust/datafusion/src/prelude.rs +++ b/rust/datafusion/src/prelude.rs @@ -28,7 +28,7 @@ pub use crate::dataframe::DataFrame; pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::logical_plan::{ - array, avg, col, concat, count, create_udf, length, lit, max, min, sum, JoinType, - Partitioning, + array, avg, col, concat, count, create_udf, length, lit, lower, max, min, sum, trim, + upper, JoinType, Partitioning, }; pub use crate::physical_plan::csv::CsvReadOptions; From 15503ee1d2452196b15b3ceae6effef6d3ce63ab Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Mon, 28 Dec 2020 10:27:32 -0700 Subject: [PATCH 02/12] ARROW-11042: [Rust][DataFusion] Increase default batch size This increases the default batch size 8x from `4096` to `32768` as it improves performance of quite some operations. I just increased the size until performance didn't increase on my machine. Note that CSV reading also is faster on bigger batches on the bigger data sources. This PR ``` Loading table 'part' into memory Loaded table 'part' into memory in 125 ms Loading table 'supplier' into memory Loaded table 'supplier' into memory in 10 ms Loading table 'partsupp' into memory Loaded table 'partsupp' into memory in 381 ms Loading table 'customer' into memory Loaded table 'customer' into memory in 126 ms Loading table 'orders' into memory Loaded table 'orders' into memory in 961 ms Loading table 'lineitem' into memory Loaded table 'lineitem' into memory in 6382 ms Loading table 'nation' into memory Loaded table 'nation' into memory in 2 ms Loading table 'region' into memory Loaded table 'region' into memory in 2 ms Query 12 iteration 0 took 220.2 ms Query 12 iteration 1 took 223.2 ms Query 12 iteration 2 took 222.4 ms Query 12 iteration 3 took 222.2 ms Query 12 iteration 4 took 221.8 ms Query 12 iteration 5 took 222.0 ms Query 12 iteration 6 took 223.1 ms Query 12 iteration 7 took 223.7 ms Query 12 iteration 8 took 222.5 ms Query 12 iteration 9 took 222.9 ms Query 12 avg time: 222.40 ms ``` Master ``` Loading table 'part' into memory Loaded table 'part' into memory in 116 ms Loading table 'supplier' into memory Loaded table 'supplier' into memory in 7 ms Loading table 'partsupp' into memory Loaded table 'partsupp' into memory in 386 ms Loading table 'customer' into memory Loaded table 'customer' into memory in 115 ms Loading table 'orders' into memory Loaded table 'orders' into memory in 1048 ms Loading table 'lineitem' into memory Loaded table 'lineitem' into memory in 7673 ms Loading table 'nation' into memory Loaded table 'nation' into memory in 0 ms Loading table 'region' into memory Loaded table 'region' into memory in 0 ms Query 12 iteration 0 took 596.1 ms Query 12 iteration 1 took 602.0 ms Query 12 iteration 2 took 608.1 ms Query 12 iteration 3 took 607.9 ms Query 12 iteration 4 took 613.5 ms Query 12 iteration 5 took 615.3 ms Query 12 iteration 6 took 611.6 ms Query 12 iteration 7 took 609.8 ms Query 12 iteration 8 took 615.7 ms Query 12 iteration 9 took 616.9 ms Query 12 avg time: 609.68 ms ``` Query 1 also improves a bit (but smaller improvement) PR. 
``` Query 1 iteration 0 took 653.0 ms Query 1 iteration 1 took 653.4 ms Query 1 iteration 2 took 652.3 ms Query 1 iteration 3 took 658.9 ms Query 1 iteration 4 took 655.1 ms Query 1 iteration 5 took 662.0 ms Query 1 iteration 6 took 659.7 ms Query 1 iteration 7 took 662.7 ms Query 1 iteration 8 took 669.0 ms Query 1 iteration 9 took 665.7 ms Query 1 avg time: 659.19 ms ``` Master: ``` Query 1 iteration 0 took 708.8 ms Query 1 iteration 1 took 714.5 ms Query 1 iteration 2 took 700.4 ms Query 1 iteration 3 took 713.7 ms Query 1 iteration 4 took 707.5 ms Query 1 iteration 5 took 727.8 ms Query 1 iteration 6 took 727.9 ms Query 1 iteration 7 took 721.3 ms Query 1 iteration 8 took 717.3 ms Query 1 iteration 9 took 729.4 ms Query 1 avg time: 716.85 ms ``` Closes #9021 from Dandandan/batch_size Authored-by: Heres, Daniel Signed-off-by: Andy Grove --- rust/benchmarks/src/bin/tpch.rs | 2 +- rust/datafusion/src/execution/context.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs index eb789baa7a2..cdffd3593cc 100644 --- a/rust/benchmarks/src/bin/tpch.rs +++ b/rust/benchmarks/src/bin/tpch.rs @@ -52,7 +52,7 @@ struct BenchmarkOpt { concurrency: usize, /// Batch size when reading CSV or Parquet files - #[structopt(short = "s", long = "batch-size", default_value = "4096")] + #[structopt(short = "s", long = "batch-size", default_value = "32768")] batch_size: usize, /// Path to data files diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index a489c117d5e..481011cf758 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -492,7 +492,7 @@ impl ExecutionConfig { pub fn new() -> Self { Self { concurrency: num_cpus::get(), - batch_size: 4096, + batch_size: 32768, query_planner: Arc::new(DefaultQueryPlanner {}), } } From c46fd102678fd22b9781642437ad8821f907d9db Mon Sep 17 00:00:00 2001 From: "Heres, Daniel" Date: Mon, 28 Dec 2020 10:37:40 -0700 Subject: [PATCH 03/12] ARROW-11046: [Rust][DataFusion] Support `count_distinct` in DataFrame API Adds `count_distinct` function. 
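An illustrative usage sketch (not part of the diff below; it assumes an async context and an `ExecutionContext` named `ctx` that already has the `aggregate_test_100` table registered, as in the updated test):

```rust
use datafusion::prelude::*;
use datafusion::logical_plan::count_distinct;

// Equivalent to:
//   SELECT c1, COUNT(c12), COUNT(DISTINCT c12)
//   FROM aggregate_test_100 GROUP BY c1
let df = ctx.table("aggregate_test_100")?;
let df = df.aggregate(
    vec![col("c1")],
    vec![count(col("c12")), count_distinct(col("c12"))],
)?;
let results = df.collect().await?;
```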
Closes #9028 from Dandandan/count_distinct Authored-by: Heres, Daniel Signed-off-by: Andy Grove --- rust/datafusion/src/execution/dataframe_impl.rs | 3 ++- rust/datafusion/src/logical_plan/expr.rs | 9 +++++++++ rust/datafusion/src/logical_plan/mod.rs | 6 +++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 7b47aa218dc..db8b86b0137 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -207,6 +207,7 @@ mod tests { avg(col("c12")), sum(col("c12")), count(col("c12")), + count_distinct(col("c12")), ]; let df = df.aggregate(group_expr, aggr_expr)?; @@ -214,7 +215,7 @@ mod tests { let plan = df.to_logical_plan(); // build same plan using SQL API - let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \ + let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12), COUNT(DISTINCT c12) \ FROM aggregate_test_100 \ GROUP BY c1"; let sql_plan = create_plan(sql)?; diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs index 0ae26a37364..4aee03c4c12 100644 --- a/rust/datafusion/src/logical_plan/expr.rs +++ b/rust/datafusion/src/logical_plan/expr.rs @@ -570,6 +570,15 @@ pub fn count(expr: Expr) -> Expr { } } +/// Create an expression to represent the count(distinct) aggregate function +pub fn count_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::Count, + distinct: true, + args: vec![expr], + } +} + /// Whether it can be represented as a literal expression pub trait Literal { /// convert the value to a Literal expression diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index f810b0162c4..0d37da6f6a3 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -35,9 +35,9 @@ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos, - count, create_udaf, create_udf, exp, exprlist_to_fields, floor, length, lit, ln, - log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, tan, trim, trunc, - upper, when, Expr, Literal, + count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, + length, lit, ln, log10, log2, lower, max, min, or, round, signum, sin, sqrt, sum, + tan, trim, trunc, upper, when, Expr, Literal, }; pub use extension::UserDefinedLogicalNode; pub use operators::Operator; From aab219019e5aa476445a791b67b2bc484efe4e48 Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 11 Dec 2020 16:56:49 -0500 Subject: [PATCH 04/12] [Rust] No need to specify binaries in the conventional location --- rust/integration-testing/Cargo.toml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml index 1c2687086fb..528341088fe 100644 --- a/rust/integration-testing/Cargo.toml +++ b/rust/integration-testing/Cargo.toml @@ -32,15 +32,3 @@ serde = { version = "1.0", features = ["rc"] } serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } hex = "0.4" - -[[bin]] -name = "arrow-file-to-stream" -path = "src/bin/arrow-file-to-stream.rs" - -[[bin]] -name = "arrow-stream-to-file" -path = "src/bin/arrow-stream-to-file.rs" - -[[bin]] -name = 
"arrow-json-integration-test" -path = "src/bin/arrow-json-integration-test.rs" From d9fd8c3b6609adb5aebdffe1b3f6eda54c64a5eb Mon Sep 17 00:00:00 2001 From: Jake Goulding Date: Fri, 11 Dec 2020 16:32:45 -0500 Subject: [PATCH 05/12] [Rust] Use generic arguments --- rust/arrow/src/ipc/writer.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs index 5161d548ca6..52fe178168a 100644 --- a/rust/arrow/src/ipc/writer.rs +++ b/rust/arrow/src/ipc/writer.rs @@ -554,10 +554,9 @@ pub struct EncodedData { /// Arrow buffers to be written, should be an empty vec for schema messages pub arrow_data: Vec, } - /// Write a message's IPC data and buffers, returning metadata and buffer data lengths written -fn write_message( - mut writer: &mut BufWriter, +pub fn write_message( + mut writer: W, encoded: EncodedData, write_options: &IpcWriteOptions, ) -> Result<(usize, usize)> { @@ -602,7 +601,7 @@ fn write_message( Ok((aligned_size, body_len)) } -fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Result { +fn write_body_buffers(mut writer: W, data: &[u8]) -> Result { let len = data.len() as u32; let pad_len = pad_to_8(len) as u32; let total_len = len + pad_len; @@ -620,7 +619,7 @@ fn write_body_buffers(writer: &mut BufWriter, data: &[u8]) -> Resul /// Write a record batch to the writer, writing the message size before the message /// if the record batch is being written to a stream fn write_continuation( - writer: &mut BufWriter, + mut writer: W, write_options: &IpcWriteOptions, total_len: i32, ) -> Result { From c0d1e6c214e7e7879dcb0834ad6a62687ca6a663 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Tue, 15 Dec 2020 16:45:44 -0500 Subject: [PATCH 06/12] [Rust] Extract conversion of EncodedData to FlightData --- rust/arrow-flight/src/utils.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index c2e01fb6ccc..a764030e246 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -23,7 +23,7 @@ use crate::{FlightData, SchemaResult}; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result}; -use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions}; +use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteOptions}; use arrow::record_batch::RecordBatch; /// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries @@ -42,13 +42,18 @@ pub fn flight_data_from_arrow_batch( encoded_dictionaries .into_iter() .chain(std::iter::once(encoded_batch)) - .map(|data| FlightData { - flight_descriptor: None, - app_metadata: vec![], + .map(Into::into) + .collect() +} + +impl From for FlightData { + fn from(data: EncodedData) -> Self { + FlightData { data_header: data.ipc_message, data_body: data.arrow_data, - }) - .collect() + ..Default::default() + } + } } /// Convert a `Schema` to `SchemaResult` by converting to an IPC message From f798e798f73c61dc8a62d54755d7238c1ae084c2 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Tue, 15 Dec 2020 16:50:35 -0500 Subject: [PATCH 07/12] [Rust] Be clearer about returned data This function returns, potentially, a few dictionaries, and will always return one record batch. Return these as a tuple rather than putting them together in one vec to clarify. 
--- rust/arrow-flight/src/utils.rs | 13 ++++++------- rust/datafusion/examples/flight_server.rs | 7 +++++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index a764030e246..aa567caaf39 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -27,11 +27,11 @@ use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteO use arrow::record_batch::RecordBatch; /// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries -/// and values +/// and a `FlightData` representing the bytes of the batch's values pub fn flight_data_from_arrow_batch( batch: &RecordBatch, options: &IpcWriteOptions, -) -> Vec { +) -> (Vec, FlightData) { let data_gen = writer::IpcDataGenerator::default(); let mut dictionary_tracker = writer::DictionaryTracker::new(false); @@ -39,11 +39,10 @@ pub fn flight_data_from_arrow_batch( .encoded_batch(batch, &mut dictionary_tracker, &options) .expect("DictionaryTracker configured above to not error on replacement"); - encoded_dictionaries - .into_iter() - .chain(std::iter::once(encoded_batch)) - .map(Into::into) - .collect() + let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + (flight_dictionaries, flight_batch) } impl From for FlightData { diff --git a/rust/datafusion/examples/flight_server.rs b/rust/datafusion/examples/flight_server.rs index d73405026bb..a9b3d251464 100644 --- a/rust/datafusion/examples/flight_server.rs +++ b/rust/datafusion/examples/flight_server.rs @@ -125,11 +125,14 @@ impl FlightService for FlightServiceImpl { let mut batches: Vec> = results .iter() .flat_map(|batch| { - let flight_data = + let (flight_dictionaries, flight_batch) = arrow_flight::utils::flight_data_from_arrow_batch( batch, &options, ); - flight_data.into_iter().map(Ok) + flight_dictionaries + .into_iter() + .chain(std::iter::once(flight_batch)) + .map(Ok) }) .collect(); From 5628adc41e9950a989b3308ae267bc9480ae490c Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Tue, 15 Dec 2020 16:59:36 -0500 Subject: [PATCH 08/12] [Rust] Extract test functions to the integration testing lib to share --- .../src/bin/arrow-json-integration-test.rs | 570 +---------------- rust/integration-testing/src/lib.rs | 575 ++++++++++++++++++ 2 files changed, 577 insertions(+), 568 deletions(-) diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs index b1bec677cf1..cd89a8edf1d 100644 --- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs +++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,27 +15,15 @@ // specific language governing permissions and limitations // under the License. 
-use std::collections::HashMap; use std::fs::File; -use std::io::BufReader; -use std::sync::Arc; use clap::{App, Arg}; -use hex::decode; -use serde_json::Value; -use arrow::array::*; -use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; -use arrow::record_batch::RecordBatch; -use arrow::{ - buffer::Buffer, - buffer::MutableBuffer, - datatypes::ToByteSlice, - util::{bit_util, integration_util::*}, -}; +use arrow::util::integration_util::*; +use arrow_integration_testing::read_json_file; fn main() -> Result<()> { let matches = App::new("rust arrow-json-integration-test") @@ -93,520 +81,6 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> Ok(()) } -fn record_batch_from_json( - schema: &Schema, - json_batch: ArrowJsonBatch, - json_dictionaries: Option<&HashMap>, -) -> Result { - let mut columns = vec![]; - - for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { - let col = array_from_json(field, json_col, json_dictionaries)?; - columns.push(col); - } - - RecordBatch::try_new(Arc::new(schema.clone()), columns) -} - -/// Construct an Arrow array from a partially typed JSON column -fn array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dictionaries: Option<&HashMap>, -) -> Result { - match field.data_type() { - DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), - DataType::Boolean => { - let mut b = BooleanBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_bool().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int8 => { - let mut b = Int8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to get {:?} as int64", - value - )) - })? 
as i8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int16 => { - let mut b = Int16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Int32 - | DataType::Date32(DateUnit::Day) - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - let mut b = Int32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_i64().unwrap() as i32), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::Int64 - | DataType::Date64(DateUnit::Millisecond) - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) - | DataType::Interval(IntervalUnit::DayTime) => { - let mut b = Int64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(match value { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => { - s.parse().expect("Unable to parse string as i64") - } - _ => panic!("Unable to parse {:?} as number", value), - }), - _ => b.append_null(), - }?; - } - let array = Arc::new(b.finish()) as ArrayRef; - arrow::compute::cast(&array, field.data_type()) - } - DataType::UInt8 => { - let mut b = UInt8Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u8), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt16 => { - let mut b = UInt16Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u16), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt32 => { - let mut b = UInt32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_u64().unwrap() as u32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::UInt64 => { - let mut b = UInt64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value( - value - .as_str() - .unwrap() - .parse() - .expect("Unable to parse string as u64"), - ), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float32 => { - let mut b = Float32Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_f64().unwrap() as f32), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Float64 => { - let mut b = Float64Builder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match 
is_valid { - 1 => b.append_value(value.as_f64().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Binary => { - let mut b = BinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeBinary => { - let mut b = LargeBinaryBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::Utf8 => { - let mut b = StringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::LargeUtf8 => { - let mut b = LargeStringBuilder::new(json_col.count); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => b.append_value(value.as_str().unwrap()), - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::FixedSizeBinary(len) => { - let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); - for (is_valid, value) in json_col - .validity - .as_ref() - .unwrap() - .iter() - .zip(json_col.data.unwrap()) - { - match is_valid { - 1 => { - let v = hex::decode(value.as_str().unwrap()).unwrap(); - b.append_value(&v) - } - _ => b.append_null(), - }?; - } - Ok(Arc::new(b.finish())) - } - DataType::List(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| v.as_i64().unwrap() as i32) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(ListArray::from(list_data))) - } - DataType::LargeList(child_field) => { - let null_buf = create_null_buf(&json_col); - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let offsets: Vec = json_col - .offset - .unwrap() - .iter() - .map(|v| match v { - Value::Number(n) => n.as_i64().unwrap(), - Value::String(s) => s.parse::().unwrap(), - _ => panic!("64-bit offset must be either string or number"), - }) - .collect(); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .offset(0) - .add_buffer(Buffer::from(&offsets.to_byte_slice())) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(LargeListArray::from(list_data))) - } - DataType::FixedSizeList(child_field, _) => { - let children = json_col.children.clone().unwrap(); - let child_array = array_from_json( - &child_field, - children.get(0).unwrap().clone(), - dictionaries, - )?; - let null_buf 
= create_null_buf(&json_col); - let list_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .add_child_data(child_array.data()) - .null_bit_buffer(null_buf) - .build(); - Ok(Arc::new(FixedSizeListArray::from(list_data))) - } - DataType::Struct(fields) => { - // construct struct with null data - let null_buf = create_null_buf(&json_col); - let mut array_data = ArrayData::builder(field.data_type().clone()) - .len(json_col.count) - .null_bit_buffer(null_buf); - - for (field, col) in fields.iter().zip(json_col.children.unwrap()) { - let array = array_from_json(field, col, dictionaries)?; - array_data = array_data.add_child_data(array.data()); - } - - let array = StructArray::from(array_data.build()); - Ok(Arc::new(array)) - } - DataType::Dictionary(key_type, value_type) => { - let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find dict_id for field {:?}", - field - )) - })?; - // find dictionary - let dictionary = dictionaries - .ok_or_else(|| { - ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {:?}", - field - )) - })? - .get(&dict_id); - match dictionary { - Some(dictionary) => dictionary_array_from_json( - field, json_col, key_type, value_type, dictionary, - ), - None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {:?}", - field - ))), - } - } - t => Err(ArrowError::JsonError(format!( - "data type {:?} not supported", - t - ))), - } -} - -fn dictionary_array_from_json( - field: &Field, - json_col: ArrowJsonColumn, - dict_key: &DataType, - dict_value: &DataType, - dictionary: &ArrowJsonDictionaryBatch, -) -> Result { - match dict_key { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => { - let null_buf = create_null_buf(&json_col); - - // build the key data into a buffer, then construct values separately - let key_field = Field::new_dict( - "key", - dict_key.clone(), - field.is_nullable(), - field - .dict_id() - .expect("Dictionary fields must have a dict_id value"), - field - .dict_is_ordered() - .expect("Dictionary fields must have a dict_is_ordered value"), - ); - let keys = array_from_json(&key_field, json_col, None)?; - // note: not enough info on nullability of dictionary - let value_field = Field::new("value", dict_value.clone(), true); - println!("dictionary value type: {:?}", dict_value); - let values = - array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; - - // convert key and value to dictionary data - let dict_data = ArrayData::builder(field.data_type().clone()) - .len(keys.len()) - .add_buffer(keys.data().buffers()[0].clone()) - .null_bit_buffer(null_buf) - .add_child_data(values.data()) - .build(); - - let array = match dict_key { - DataType::Int8 => { - Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef - } - DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), - DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), - DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), - DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), - DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), - DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), - DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), - _ => unreachable!(), - }; - Ok(array) - } - _ => Err(ArrowError::JsonError(format!( - "Dictionary key type {:?} not 
supported", - dict_key - ))), - } -} - -/// A helper to create a null buffer from a Vec -fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { - let num_bytes = bit_util::ceil(json_col.count, 8); - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); - json_col - .validity - .clone() - .unwrap() - .iter() - .enumerate() - .for_each(|(i, v)| { - let null_slice = null_buf.data_mut(); - if *v != 0 { - bit_util::set_bit(null_slice, i); - } - }); - null_buf.freeze() -} - fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { eprintln!("Converting {} to {}", arrow_name, json_name); @@ -702,43 +176,3 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { Ok(()) } - -struct ArrowFile { - schema: Schema, - // we can evolve this into a concrete Arrow type - // this is temporarily not being read from - _dictionaries: HashMap, - batches: Vec, -} - -fn read_json_file(json_name: &str) -> Result { - let json_file = File::open(json_name)?; - let reader = BufReader::new(json_file); - let arrow_json: Value = serde_json::from_reader(reader).unwrap(); - let schema = Schema::from(&arrow_json["schema"])?; - // read dictionaries - let mut dictionaries = HashMap::new(); - if let Some(dicts) = arrow_json.get("dictionaries") { - for d in dicts - .as_array() - .expect("Unable to get dictionaries as array") - { - let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) - .expect("Unable to get dictionary from JSON"); - // TODO: convert to a concrete Arrow type - dictionaries.insert(json_dict.id, json_dict); - } - } - - let mut batches = vec![]; - for b in arrow_json["batches"].as_array().unwrap() { - let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); - let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; - batches.push(batch); - } - Ok(ArrowFile { - schema, - _dictionaries: dictionaries, - batches, - }) -} diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs index 596017a79bd..b93f1c4aa51 100644 --- a/rust/integration-testing/src/lib.rs +++ b/rust/integration-testing/src/lib.rs @@ -16,3 +16,578 @@ // under the License. //! 
Common code used in the integration test binaries + +use hex::decode; +use serde_json::Value; + +use arrow::util::integration_util::ArrowJsonBatch; + +use arrow::array::*; +use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema}; +use arrow::error::{ArrowError, Result}; +use arrow::record_batch::RecordBatch; +use arrow::{ + buffer::Buffer, + buffer::MutableBuffer, + datatypes::ToByteSlice, + util::{bit_util, integration_util::*}, +}; + +use std::collections::HashMap; +use std::fs::File; +use std::io::BufReader; +use std::sync::Arc; + +pub struct ArrowFile { + pub schema: Schema, + // we can evolve this into a concrete Arrow type + // this is temporarily not being read from + pub _dictionaries: HashMap, + pub batches: Vec, +} + +pub fn read_json_file(json_name: &str) -> Result { + let json_file = File::open(json_name)?; + let reader = BufReader::new(json_file); + let arrow_json: Value = serde_json::from_reader(reader).unwrap(); + let schema = Schema::from(&arrow_json["schema"])?; + // read dictionaries + let mut dictionaries = HashMap::new(); + if let Some(dicts) = arrow_json.get("dictionaries") { + for d in dicts + .as_array() + .expect("Unable to get dictionaries as array") + { + let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone()) + .expect("Unable to get dictionary from JSON"); + // TODO: convert to a concrete Arrow type + dictionaries.insert(json_dict.id, json_dict); + } + } + + let mut batches = vec![]; + for b in arrow_json["batches"].as_array().unwrap() { + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; + batches.push(batch); + } + Ok(ArrowFile { + schema, + _dictionaries: dictionaries, + batches, + }) +} + +fn record_batch_from_json( + schema: &Schema, + json_batch: ArrowJsonBatch, + json_dictionaries: Option<&HashMap>, +) -> Result { + let mut columns = vec![]; + + for (field, json_col) in schema.fields().iter().zip(json_batch.columns) { + let col = array_from_json(field, json_col, json_dictionaries)?; + columns.push(col); + } + + RecordBatch::try_new(Arc::new(schema.clone()), columns) +} + +/// Construct an Arrow array from a partially typed JSON column +fn array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dictionaries: Option<&HashMap>, +) -> Result { + match field.data_type() { + DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))), + DataType::Boolean => { + let mut b = BooleanBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_bool().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int8 => { + let mut b = Int8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to get {:?} as int64", + value + )) + })? 
as i8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int16 => { + let mut b = Int16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Int32 + | DataType::Date32(DateUnit::Day) + | DataType::Time32(_) + | DataType::Interval(IntervalUnit::YearMonth) => { + let mut b = Int32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_i64().unwrap() as i32), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::Int64 + | DataType::Date64(DateUnit::Millisecond) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Duration(_) + | DataType::Interval(IntervalUnit::DayTime) => { + let mut b = Int64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(match value { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => { + s.parse().expect("Unable to parse string as i64") + } + _ => panic!("Unable to parse {:?} as number", value), + }), + _ => b.append_null(), + }?; + } + let array = Arc::new(b.finish()) as ArrayRef; + arrow::compute::cast(&array, field.data_type()) + } + DataType::UInt8 => { + let mut b = UInt8Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u8), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt16 => { + let mut b = UInt16Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u16), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt32 => { + let mut b = UInt32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_u64().unwrap() as u32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::UInt64 => { + let mut b = UInt64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value( + value + .as_str() + .unwrap() + .parse() + .expect("Unable to parse string as u64"), + ), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float32 => { + let mut b = Float32Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_f64().unwrap() as f32), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Float64 => { + let mut b = Float64Builder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match 
is_valid { + 1 => b.append_value(value.as_f64().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Binary => { + let mut b = BinaryBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeBinary => { + let mut b = LargeBinaryBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::Utf8 => { + let mut b = StringBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::LargeUtf8 => { + let mut b = LargeStringBuilder::new(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap()), + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::FixedSizeBinary(len) => { + let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let v = hex::decode(value.as_str().unwrap()).unwrap(); + b.append_value(&v) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } + DataType::List(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| v.as_i64().unwrap() as i32) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(ListArray::from(list_data))) + } + DataType::LargeList(child_field) => { + let null_buf = create_null_buf(&json_col); + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let offsets: Vec = json_col + .offset + .unwrap() + .iter() + .map(|v| match v { + Value::Number(n) => n.as_i64().unwrap(), + Value::String(s) => s.parse::().unwrap(), + _ => panic!("64-bit offset must be either string or number"), + }) + .collect(); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .offset(0) + .add_buffer(Buffer::from(&offsets.to_byte_slice())) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(LargeListArray::from(list_data))) + } + DataType::FixedSizeList(child_field, _) => { + let children = json_col.children.clone().unwrap(); + let child_array = array_from_json( + &child_field, + children.get(0).unwrap().clone(), + dictionaries, + )?; + let null_buf 
= create_null_buf(&json_col); + let list_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .add_child_data(child_array.data()) + .null_bit_buffer(null_buf) + .build(); + Ok(Arc::new(FixedSizeListArray::from(list_data))) + } + DataType::Struct(fields) => { + // construct struct with null data + let null_buf = create_null_buf(&json_col); + let mut array_data = ArrayData::builder(field.data_type().clone()) + .len(json_col.count) + .null_bit_buffer(null_buf); + + for (field, col) in fields.iter().zip(json_col.children.unwrap()) { + let array = array_from_json(field, col, dictionaries)?; + array_data = array_data.add_child_data(array.data()); + } + + let array = StructArray::from(array_data.build()); + Ok(Arc::new(array)) + } + DataType::Dictionary(key_type, value_type) => { + let dict_id = field.dict_id().ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find dict_id for field {:?}", + field + )) + })?; + // find dictionary + let dictionary = dictionaries + .ok_or_else(|| { + ArrowError::JsonError(format!( + "Unable to find any dictionaries for field {:?}", + field + )) + })? + .get(&dict_id); + match dictionary { + Some(dictionary) => dictionary_array_from_json( + field, json_col, key_type, value_type, dictionary, + ), + None => Err(ArrowError::JsonError(format!( + "Unable to find dictionary for field {:?}", + field + ))), + } + } + t => Err(ArrowError::JsonError(format!( + "data type {:?} not supported", + t + ))), + } +} + +fn dictionary_array_from_json( + field: &Field, + json_col: ArrowJsonColumn, + dict_key: &DataType, + dict_value: &DataType, + dictionary: &ArrowJsonDictionaryBatch, +) -> Result { + match dict_key { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + let null_buf = create_null_buf(&json_col); + + // build the key data into a buffer, then construct values separately + let key_field = Field::new_dict( + "key", + dict_key.clone(), + field.is_nullable(), + field + .dict_id() + .expect("Dictionary fields must have a dict_id value"), + field + .dict_is_ordered() + .expect("Dictionary fields must have a dict_is_ordered value"), + ); + let keys = array_from_json(&key_field, json_col, None)?; + // note: not enough info on nullability of dictionary + let value_field = Field::new("value", dict_value.clone(), true); + println!("dictionary value type: {:?}", dict_value); + let values = + array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?; + + // convert key and value to dictionary data + let dict_data = ArrayData::builder(field.data_type().clone()) + .len(keys.len()) + .add_buffer(keys.data().buffers()[0].clone()) + .null_bit_buffer(null_buf) + .add_child_data(values.data()) + .build(); + + let array = match dict_key { + DataType::Int8 => { + Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef + } + DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)), + DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)), + DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)), + DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)), + DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)), + DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)), + DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)), + _ => unreachable!(), + }; + Ok(array) + } + _ => Err(ArrowError::JsonError(format!( + "Dictionary key type {:?} not 
supported", + dict_key + ))), + } +} + +/// A helper to create a null buffer from a Vec +fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { + let num_bytes = bit_util::ceil(json_col.count, 8); + let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false); + json_col + .validity + .clone() + .unwrap() + .iter() + .enumerate() + .for_each(|(i, v)| { + let null_slice = null_buf.data_mut(); + if *v != 0 { + bit_util::set_bit(null_slice, i); + } + }); + null_buf.freeze() +} From 2ee1a38ae9074e38ae3db8d08ef657ceeb6c84b8 Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 16 Dec 2020 15:09:22 -0500 Subject: [PATCH 09/12] [Rust] Support specifying dicts for FlightData to RecordBatch --- rust/arrow-flight/src/utils.rs | 4 ++-- rust/datafusion/examples/flight_client.rs | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index aa567caaf39..0ff0fb43214 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -21,6 +21,7 @@ use std::convert::TryFrom; use crate::{FlightData, SchemaResult}; +use arrow::array::ArrayRef; use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::{ArrowError, Result}; use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteOptions}; @@ -117,6 +118,7 @@ impl TryFrom<&SchemaResult> for Schema { pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, + dictionaries_by_field: &[Option], ) -> Option> { // check that the data_header is a record batch message let res = arrow::ipc::root_as_message(&data.data_header[..]); @@ -131,8 +133,6 @@ pub fn flight_data_to_arrow_batch( let message = res.unwrap(); - let dictionaries_by_field = Vec::new(); - message .header_as_record_batch() .ok_or_else(|| { diff --git a/rust/datafusion/examples/flight_client.rs b/rust/datafusion/examples/flight_client.rs index 13fd394d187..1e0b56bc8a8 100644 --- a/rust/datafusion/examples/flight_client.rs +++ b/rust/datafusion/examples/flight_client.rs @@ -62,10 +62,15 @@ async fn main() -> Result<(), Box> { // all the remaining stream messages should be dictionary and record batches let mut results = vec![]; + let dictionaries_by_field = vec![None; schema.fields().len()]; while let Some(flight_data) = stream.message().await? { // the unwrap is infallible and thus safe - let record_batch = - flight_data_to_arrow_batch(&flight_data, schema.clone()).unwrap()?; + let record_batch = flight_data_to_arrow_batch( + &flight_data, + schema.clone(), + &dictionaries_by_field, + ) + .unwrap()?; results.push(record_batch); } From 4f74717d6bd636fb8f56474e5b3c68579d4a4e1e Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 16 Dec 2020 16:40:36 -0500 Subject: [PATCH 10/12] [Rust] Remove Option that is always Some This doesn't seem necessary --- rust/arrow-flight/src/utils.rs | 25 ++++++++--------------- rust/datafusion/examples/flight_client.rs | 4 +--- 2 files changed, 9 insertions(+), 20 deletions(-) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index 0ff0fb43214..b58393147b6 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -119,19 +119,11 @@ pub fn flight_data_to_arrow_batch( data: &FlightData, schema: SchemaRef, dictionaries_by_field: &[Option], -) -> Option> { +) -> Result { // check that the data_header is a record batch message - let res = arrow::ipc::root_as_message(&data.data_header[..]); - - // Catch error. 
- if let Err(err) = res { - return Some(Err(ArrowError::ParseError(format!( - "Unable to get root as message: {:?}", - err - )))); - } - - let message = res.unwrap(); + let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| { + ArrowError::ParseError(format!("Unable to get root as message: {:?}", err)) + })?; message .header_as_record_batch() @@ -140,17 +132,16 @@ pub fn flight_data_to_arrow_batch( "Unable to convert flight data header to a record batch".to_string(), ) }) - .map_or_else( - |err| Some(Err(err)), + .map( |batch| { - Some(reader::read_record_batch( + reader::read_record_batch( &data.data_body, batch, schema, &dictionaries_by_field, - )) + ) }, - ) + )? } // TODO: add more explicit conversion that exposes flight descriptor and metadata options diff --git a/rust/datafusion/examples/flight_client.rs b/rust/datafusion/examples/flight_client.rs index 1e0b56bc8a8..2c2954d5a02 100644 --- a/rust/datafusion/examples/flight_client.rs +++ b/rust/datafusion/examples/flight_client.rs @@ -64,13 +64,11 @@ async fn main() -> Result<(), Box> { let mut results = vec![]; let dictionaries_by_field = vec![None; schema.fields().len()]; while let Some(flight_data) = stream.message().await? { - // the unwrap is infallible and thus safe let record_batch = flight_data_to_arrow_batch( &flight_data, schema.clone(), &dictionaries_by_field, - ) - .unwrap()?; + )?; results.push(record_batch); } From 94f69ae48b71408ee66d0c4fb3bbc3412b5f3ced Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Fri, 18 Dec 2020 11:26:18 -0500 Subject: [PATCH 11/12] [Rust] Skip middleware flight integration tests until tonic upgrade Tracked in ARROW-10961. There's a bug in tonic that doesn't handle headers and trailers correctly; it has been fixed but a new version of tonic needs to be released and used to get the fix. 
--- dev/archery/archery/integration/runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index c1d7a697ab0..77390fd7d61 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -347,7 +347,9 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Authenticate using the BasicAuth protobuf."), Scenario( "middleware", - description="Ensure headers are propagated via middleware."), + description="Ensure headers are propagated via middleware.", + skip="Rust" # TODO(ARROW-10961): tonic upgrade needed + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) From 116611c6472abd296fe8a2c35cf1d0ade565303a Mon Sep 17 00:00:00 2001 From: "Carol (Nichols || Goulding)" Date: Wed, 16 Dec 2020 15:18:46 -0500 Subject: [PATCH 12/12] ARROW-8853: [Rust] [Integration Testing] Enable Flight tests --- .../archery/integration/tester_rust.py | 86 ++-- rust/arrow-flight/src/utils.rs | 60 ++- rust/arrow/src/ipc/reader.rs | 2 +- rust/integration-testing/Cargo.toml | 8 +- .../src/bin/flight-test-integration-client.rs | 59 +++ .../src/bin/flight-test-integration-server.rs | 52 +++ .../src/flight_client_scenarios.rs | 20 + .../auth_basic_proto.rs | 109 +++++ .../integration_test.rs | 266 ++++++++++++ .../src/flight_client_scenarios/middleware.rs | 82 ++++ .../src/flight_server_scenarios.rs | 49 +++ .../auth_basic_proto.rs | 226 ++++++++++ .../integration_test.rs | 385 ++++++++++++++++++ .../src/flight_server_scenarios/middleware.rs | 150 +++++++ rust/integration-testing/src/lib.rs | 8 + 15 files changed, 1502 insertions(+), 60 deletions(-) create mode 100644 rust/integration-testing/src/bin/flight-test-integration-client.rs create mode 100644 rust/integration-testing/src/bin/flight-test-integration-server.rs create mode 100644 rust/integration-testing/src/flight_client_scenarios.rs create mode 100644 rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs create mode 100644 rust/integration-testing/src/flight_client_scenarios/integration_test.rs create mode 100644 rust/integration-testing/src/flight_client_scenarios/middleware.rs create mode 100644 rust/integration-testing/src/flight_server_scenarios.rs create mode 100644 rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs create mode 100644 rust/integration-testing/src/flight_server_scenarios/integration_test.rs create mode 100644 rust/integration-testing/src/flight_server_scenarios/middleware.rs diff --git a/dev/archery/archery/integration/tester_rust.py b/dev/archery/archery/integration/tester_rust.py index 23c2d37386a..bca80ebae3c 100644 --- a/dev/archery/archery/integration/tester_rust.py +++ b/dev/archery/archery/integration/tester_rust.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. 
+import contextlib import os +import subprocess from .tester import Tester from .util import run_cmd, ARROW_ROOT_DEFAULT, log @@ -24,8 +26,8 @@ class RustTester(Tester): PRODUCER = True CONSUMER = True - # FLIGHT_SERVER = True - # FLIGHT_CLIENT = True + FLIGHT_SERVER = True + FLIGHT_CLIENT = True EXE_PATH = os.path.join(ARROW_ROOT_DEFAULT, 'rust/target/debug') @@ -34,11 +36,11 @@ class RustTester(Tester): STREAM_TO_FILE = os.path.join(EXE_PATH, 'arrow-stream-to-file') FILE_TO_STREAM = os.path.join(EXE_PATH, 'arrow-file-to-stream') - # FLIGHT_SERVER_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-server')] - # FLIGHT_CLIENT_CMD = [ - # os.path.join(EXE_PATH, 'flight-test-integration-client'), - # "-host", "localhost"] + FLIGHT_SERVER_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-server')] + FLIGHT_CLIENT_CMD = [ + os.path.join(EXE_PATH, 'flight-test-integration-client'), + "--host", "localhost"] name = 'Rust' @@ -72,34 +74,42 @@ def file_to_stream(self, file_path, stream_path): cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path] self.run_shell_command(cmd) - # @contextlib.contextmanager - # def flight_server(self): - # cmd = self.FLIGHT_SERVER_CMD + ['-port=0'] - # if self.debug: - # log(' '.join(cmd)) - # server = subprocess.Popen(cmd, - # stdout=subprocess.PIPE, - # stderr=subprocess.PIPE) - # try: - # output = server.stdout.readline().decode() - # if not output.startswith("Server listening on localhost:"): - # server.kill() - # out, err = server.communicate() - # raise RuntimeError( - # "Flight-C++ server did not start properly, " - # "stdout:\n{}\n\nstderr:\n{}\n" - # .format(output + out.decode(), err.decode())) - # port = int(output.split(":")[1]) - # yield port - # finally: - # server.kill() - # server.wait(5) - - # def flight_request(self, port, json_path): - # cmd = self.FLIGHT_CLIENT_CMD + [ - # '-port=' + str(port), - # '-path=' + json_path, - # ] - # if self.debug: - # log(' '.join(cmd)) - # run_cmd(cmd) + @contextlib.contextmanager + def flight_server(self, scenario_name=None): + cmd = self.FLIGHT_SERVER_CMD + ['--port=0'] + if scenario_name: + cmd = cmd + ["--scenario", scenario_name] + if self.debug: + log(' '.join(cmd)) + server = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + try: + output = server.stdout.readline().decode() + if not output.startswith("Server listening on localhost:"): + server.kill() + out, err = server.communicate() + raise RuntimeError( + "Flight-Rust server did not start properly, " + "stdout:\n{}\n\nstderr:\n{}\n" + .format(output + out.decode(), err.decode())) + port = int(output.split(":")[1]) + yield port + finally: + server.kill() + server.wait(5) + + def flight_request(self, port, json_path=None, scenario_name=None): + cmd = self.FLIGHT_CLIENT_CMD + [ + '--port=' + str(port), + ] + if json_path: + cmd.extend(('--path', json_path)) + elif scenario_name: + cmd.extend(('--scenario', scenario_name)) + else: + raise TypeError("Must provide one of json_path or scenario_name") + + if self.debug: + log(' '.join(cmd)) + run_cmd(cmd) diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs index b58393147b6..659668c0baf 100644 --- a/rust/arrow-flight/src/utils.rs +++ b/rust/arrow-flight/src/utils.rs @@ -61,11 +61,8 @@ pub fn flight_schema_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, ) -> SchemaResult { - let data_gen = writer::IpcDataGenerator::default(); - let schema_bytes = data_gen.schema_to_bytes(schema, &options); - SchemaResult { - schema: 
schema_bytes.ipc_message, + schema: flight_schema_as_flatbuffer(schema, options), } } @@ -74,16 +71,41 @@ pub fn flight_data_from_arrow_schema( schema: &Schema, options: &IpcWriteOptions, ) -> FlightData { - let data_gen = writer::IpcDataGenerator::default(); - let schema = data_gen.schema_to_bytes(schema, &options); + let data_header = flight_schema_as_flatbuffer(schema, options); FlightData { - flight_descriptor: None, - app_metadata: vec![], - data_header: schema.ipc_message, - data_body: vec![], + data_header, + ..Default::default() } } +/// Convert a `Schema` to bytes in the format expected in `FlightInfo.schema` +pub fn ipc_message_from_arrow_schema( + arrow_schema: &Schema, + options: &IpcWriteOptions, +) -> Result> { + let encoded_data = flight_schema_as_encoded_data(arrow_schema, options); + + let mut schema = vec![]; + arrow::ipc::writer::write_message(&mut schema, encoded_data, options)?; + Ok(schema) +} + +fn flight_schema_as_flatbuffer( + arrow_schema: &Schema, + options: &IpcWriteOptions, +) -> Vec { + let encoded_data = flight_schema_as_encoded_data(arrow_schema, options); + encoded_data.ipc_message +} + +fn flight_schema_as_encoded_data( + arrow_schema: &Schema, + options: &IpcWriteOptions, +) -> EncodedData { + let data_gen = writer::IpcDataGenerator::default(); + data_gen.schema_to_bytes(arrow_schema, options) +} + /// Try convert `FlightData` into an Arrow Schema /// /// Returns an error if the `FlightData` header is not a valid IPC schema @@ -132,16 +154,14 @@ pub fn flight_data_to_arrow_batch( "Unable to convert flight data header to a record batch".to_string(), ) }) - .map( - |batch| { - reader::read_record_batch( - &data.data_body, - batch, - schema, - &dictionaries_by_field, - ) - }, - )? + .map(|batch| { + reader::read_record_batch( + &data.data_body, + batch, + schema, + &dictionaries_by_field, + ) + })? } // TODO: add more explicit conversion that exposes flight descriptor and metadata options diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs index 809e7177210..7e6d7962a6e 100644 --- a/rust/arrow/src/ipc/reader.rs +++ b/rust/arrow/src/ipc/reader.rs @@ -473,7 +473,7 @@ pub fn read_record_batch( /// Read the dictionary from the buffer and provided metadata, /// updating the `dictionaries_by_field` with the resulting dictionary -fn read_dictionary( +pub fn read_dictionary( buf: &[u8], batch: ipc::DictionaryBatch, schema: &Schema, diff --git a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml index 528341088fe..4c6d5d02e1a 100644 --- a/rust/integration-testing/Cargo.toml +++ b/rust/integration-testing/Cargo.toml @@ -27,8 +27,14 @@ edition = "2018" [dependencies] arrow = { path = "../arrow" } +arrow-flight = { path = "../arrow-flight" } +async-trait = "0.1.41" clap = "2.33" +futures = "0.3" +hex = "0.4" +prost = "0.6" serde = { version = "1.0", features = ["rc"] } serde_derive = "1.0" serde_json = { version = "1.0", features = ["preserve_order"] } -hex = "0.4" +tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] } +tonic = "0.3" diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs new file mode 100644 index 00000000000..17352360f85 --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_integration_testing::flight_client_scenarios; + +use clap::{App, Arg}; + +type Error = Box; +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result { + let matches = App::new("rust flight-test-integration-client") + .arg(Arg::with_name("host").long("host").takes_value(true)) + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg(Arg::with_name("path").long("path").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let host = matches.value_of("host").expect("Host is required"); + let port = matches.value_of("port").expect("Port is required"); + + match matches.value_of("scenario") { + Some("middleware") => { + flight_client_scenarios::middleware::run_scenario(host, port).await? + } + Some("auth:basic_proto") => { + flight_client_scenarios::auth_basic_proto::run_scenario(host, port).await? + } + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + let path = matches + .value_of("path") + .expect("Path is required if scenario is not specified"); + flight_client_scenarios::integration_test::run_scenario(host, port, path) + .await?; + } + } + + Ok(()) +} diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs new file mode 100644 index 00000000000..5ef1253d1ee --- /dev/null +++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use clap::{App, Arg}; + +use arrow_integration_testing::flight_server_scenarios; + +type Error = Box; +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result { + let matches = App::new("rust flight-test-integration-server") + .about("Integration testing server for Flight.") + .arg(Arg::with_name("port").long("port").takes_value(true)) + .arg( + Arg::with_name("scenario") + .long("scenario") + .takes_value(true), + ) + .get_matches(); + + let port = matches.value_of("port").unwrap_or("0"); + + match matches.value_of("scenario") { + Some("middleware") => { + flight_server_scenarios::middleware::scenario_setup(port).await? + } + Some("auth:basic_proto") => { + flight_server_scenarios::auth_basic_proto::scenario_setup(port).await? + } + Some(scenario_name) => unimplemented!("Scenario not found: {}", scenario_name), + None => { + flight_server_scenarios::integration_test::scenario_setup(port).await?; + } + } + Ok(()) +} diff --git a/rust/integration-testing/src/flight_client_scenarios.rs b/rust/integration-testing/src/flight_client_scenarios.rs new file mode 100644 index 00000000000..66cced5f4c2 --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod auth_basic_proto; +pub mod integration_test; +pub mod middleware; diff --git a/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs new file mode 100644 index 00000000000..5e8cd467198 --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{AUTH_PASSWORD, AUTH_USERNAME}; + +use arrow_flight::{ + flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest, +}; +use futures::{stream, StreamExt}; +use prost::Message; +use tonic::{metadata::MetadataValue, Request, Status}; + +type Error = Box; +type Result = std::result::Result; + +type Client = FlightServiceClient; + +pub async fn run_scenario(host: &str, port: &str) -> Result { + let url = format!("http://{}:{}", host, port); + let mut client = FlightServiceClient::connect(url).await?; + + let action = arrow_flight::Action::default(); + + let resp = client.do_action(Request::new(action.clone())).await; + // This client is unauthenticated and should fail. + match resp { + Err(e) => { + if e.code() != tonic::Code::Unauthenticated { + return Err(Box::new(Status::internal(format!( + "Expected UNAUTHENTICATED but got {:?}", + e + )))); + } + } + Ok(other) => { + return Err(Box::new(Status::internal(format!( + "Expected UNAUTHENTICATED but got {:?}", + other + )))); + } + } + + let token = authenticate(&mut client, AUTH_USERNAME, AUTH_PASSWORD) + .await + .expect("must respond successfully from handshake"); + + let mut request = Request::new(action); + let metadata = request.metadata_mut(); + metadata.insert_bin( + "auth-token-bin", + MetadataValue::from_bytes(token.as_bytes()), + ); + + let resp = client.do_action(request).await?; + let mut resp = resp.into_inner(); + + let r = resp + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + + let body = String::from_utf8(r.body).unwrap(); + assert_eq!(body, AUTH_USERNAME); + + Ok(()) +} + +async fn authenticate( + client: &mut Client, + username: &str, + password: &str, +) -> Result { + let auth = BasicAuth { + username: username.into(), + password: password.into(), + }; + let mut payload = vec![]; + auth.encode(&mut payload)?; + + let req = stream::once(async { + HandshakeRequest { + payload, + ..HandshakeRequest::default() + } + }); + + let rx = client.handshake(Request::new(req)).await?; + let mut rx = rx.into_inner(); + + let r = rx.next().await.expect("must respond from handshake")?; + assert!(rx.next().await.is_none(), "must not respond a second time"); + + Ok(String::from_utf8(r.payload).unwrap()) +} diff --git a/rust/integration-testing/src/flight_client_scenarios/integration_test.rs b/rust/integration-testing/src/flight_client_scenarios/integration_test.rs new file mode 100644 index 00000000000..5705f80f82c --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios/integration_test.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{read_json_file, ArrowFile}; + +use arrow::{ + array::ArrayRef, + datatypes::SchemaRef, + ipc::{self, reader, writer}, + record_batch::RecordBatch, +}; +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, + utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, Ticket, +}; +use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; +use tonic::{Request, Streaming}; + +use std::sync::Arc; + +type Error = Box; +type Result = std::result::Result; + +type Client = FlightServiceClient; + +pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result { + let url = format!("http://{}:{}", host, port); + + let client = FlightServiceClient::connect(url).await?; + + let ArrowFile { + schema, batches, .. + } = read_json_file(path)?; + + let schema = Arc::new(schema); + + let mut descriptor = FlightDescriptor::default(); + descriptor.set_type(DescriptorType::Path); + descriptor.path = vec![path.to_string()]; + + upload_data( + client.clone(), + schema.clone(), + descriptor.clone(), + batches.clone(), + ) + .await?; + verify_data(client, descriptor, schema, &batches).await?; + + Ok(()) +} + +async fn upload_data( + mut client: Client, + schema: SchemaRef, + descriptor: FlightDescriptor, + original_data: Vec, +) -> Result { + let (mut upload_tx, upload_rx) = mpsc::channel(10); + + let options = arrow::ipc::writer::IpcWriteOptions::default(); + let mut schema_flight_data = + arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options); + schema_flight_data.flight_descriptor = Some(descriptor.clone()); + upload_tx.send(schema_flight_data).await?; + + let mut original_data_iter = original_data.iter().enumerate(); + + if let Some((counter, first_batch)) = original_data_iter.next() { + let metadata = counter.to_string().into_bytes(); + // Preload the first batch into the channel before starting the request + send_batch(&mut upload_tx, &metadata, first_batch, &options).await?; + + let outer = client.do_put(Request::new(upload_rx)).await?; + let mut inner = outer.into_inner(); + + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + + // Stream the rest of the batches + for (counter, batch) in original_data_iter { + let metadata = counter.to_string().into_bytes(); + send_batch(&mut upload_tx, &metadata, batch, &options).await?; + + let r = inner + .next() + .await + .expect("No response received") + .expect("Invalid response received"); + assert_eq!(metadata, r.app_metadata); + } + } else { + drop(upload_tx); + client.do_put(Request::new(upload_rx)).await?; + } + + Ok(()) +} + +async fn send_batch( + upload_tx: &mut mpsc::Sender, + metadata: &[u8], + batch: &RecordBatch, + options: &writer::IpcWriteOptions, +) -> Result { + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + + upload_tx + .send_all(&mut stream::iter(dictionary_flight_data).map(Ok)) + .await?; + + // Only the record batch's FlightData gets app_metadata + batch_flight_data.app_metadata = metadata.to_vec(); + upload_tx.send(batch_flight_data).await?; + Ok(()) +} + +async fn verify_data( + mut client: Client, + descriptor: FlightDescriptor, + expected_schema: SchemaRef, + expected_data: &[RecordBatch], +) -> Result { + let resp = client.get_flight_info(Request::new(descriptor)).await?; + let info = resp.into_inner(); + + assert!( + !info.endpoint.is_empty(), + "No endpoints 
returned from Flight server", + ); + for endpoint in info.endpoint { + let ticket = endpoint + .ticket + .expect("No ticket returned from Flight server"); + + assert!( + !endpoint.location.is_empty(), + "No locations returned from Flight server", + ); + for location in endpoint.location { + consume_flight_location( + location, + ticket.clone(), + &expected_data, + expected_schema.clone(), + ) + .await?; + } + } + + Ok(()) +} + +async fn consume_flight_location( + location: Location, + ticket: Ticket, + expected_data: &[RecordBatch], + schema: SchemaRef, +) -> Result { + let mut location = location; + // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs + // don't recognize this as valid. + location.uri = location.uri.replace("grpc+tcp://", "grpc://"); + + let mut client = FlightServiceClient::connect(location.uri).await?; + let resp = client.do_get(ticket).await?; + let mut resp = resp.into_inner(); + + // We already have the schema from the FlightInfo, but the server sends it again as the + // first FlightData. Ignore this one. + let _schema_again = resp.next().await.unwrap(); + + let mut dictionaries_by_field = vec![None; schema.fields().len()]; + + for (counter, expected_batch) in expected_data.iter().enumerate() { + let data = receive_batch_flight_data( + &mut resp, + schema.clone(), + &mut dictionaries_by_field, + ) + .await + .unwrap_or_else(|| { + panic!( + "Got fewer batches than expected, received so far: {} expected: {}", + counter, + expected_data.len(), + ) + }); + + let metadata = counter.to_string().into_bytes(); + assert_eq!(metadata, data.app_metadata); + + let actual_batch = + flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_field) + .expect("Unable to convert flight data to Arrow batch"); + + assert_eq!(expected_batch.schema(), actual_batch.schema()); + assert_eq!(expected_batch.num_columns(), actual_batch.num_columns()); + assert_eq!(expected_batch.num_rows(), actual_batch.num_rows()); + let schema = expected_batch.schema(); + for i in 0..expected_batch.num_columns() { + let field = schema.field(i); + let field_name = field.name(); + + let expected_data = expected_batch.column(i).data(); + let actual_data = actual_batch.column(i).data(); + + assert_eq!(expected_data, actual_data, "Data for field {}", field_name); + } + } + + assert!( + resp.next().await.is_none(), + "Got more batches than the expected: {}", + expected_data.len(), + ); + + Ok(()) +} + +async fn receive_batch_flight_data( + resp: &mut Streaming, + schema: SchemaRef, + dictionaries_by_field: &mut [Option], +) -> Option { + let mut data = resp.next().await?.ok()?; + let mut message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing first message"); + + while message.header_type() == ipc::MessageHeader::DictionaryBatch { + reader::read_dictionary( + &data.data_body, + message + .header_as_dictionary_batch() + .expect("Error parsing dictionary"), + &schema, + dictionaries_by_field, + ) + .expect("Error reading dictionary"); + + data = resp.next().await?.ok()?; + message = arrow::ipc::root_as_message(&data.data_header[..]) + .expect("Error parsing message"); + } + + Some(data) +} diff --git a/rust/integration-testing/src/flight_client_scenarios/middleware.rs b/rust/integration-testing/src/flight_client_scenarios/middleware.rs new file mode 100644 index 00000000000..607eab1018a --- /dev/null +++ b/rust/integration-testing/src/flight_client_scenarios/middleware.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, + FlightDescriptor, +}; +use tonic::{Request, Status}; + +type Error = Box; +type Result = std::result::Result; + +pub async fn run_scenario(host: &str, port: &str) -> Result { + let url = format!("http://{}:{}", host, port); + let conn = tonic::transport::Endpoint::new(url)?.connect().await?; + let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor); + + let mut descriptor = FlightDescriptor::default(); + descriptor.set_type(DescriptorType::Cmd); + descriptor.cmd = b"".to_vec(); + + // This call is expected to fail. + match client + .get_flight_info(Request::new(descriptor.clone())) + .await + { + Ok(_) => return Err(Box::new(Status::internal("Expected call to fail"))), + Err(e) => { + let headers = e.metadata(); + let middleware_header = headers.get("x-middleware"); + let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or(""); + + if value != "expected value" { + let msg = format!( + "On failing call: Expected to receive header 'x-middleware: expected value', \ + but instead got: '{}'", + value + ); + return Err(Box::new(Status::internal(msg))); + } + } + } + + // This call should succeed + descriptor.cmd = b"success".to_vec(); + let resp = client.get_flight_info(Request::new(descriptor)).await?; + + let headers = resp.metadata(); + let middleware_header = headers.get("x-middleware"); + let value = middleware_header.map(|v| v.to_str().unwrap()).unwrap_or(""); + + if value != "expected value" { + let msg = format!( + "On success call: Expected to receive header 'x-middleware: expected value', \ + but instead got: '{}'", + value + ); + return Err(Box::new(Status::internal(msg))); + } + + Ok(()) +} + +fn middleware_interceptor(mut req: Request<()>) -> Result, Status> { + let metadata = req.metadata_mut(); + metadata.insert("x-middleware", "expected value".parse().unwrap()); + Ok(req) +} diff --git a/rust/integration-testing/src/flight_server_scenarios.rs b/rust/integration-testing/src/flight_server_scenarios.rs new file mode 100644 index 00000000000..3d99e535f76 --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::net::SocketAddr; + +use arrow_flight::{FlightEndpoint, Location, Ticket}; +use tokio::net::TcpListener; + +pub mod auth_basic_proto; +pub mod integration_test; +pub mod middleware; + +type Error = Box; +type Result = std::result::Result; + +pub async fn listen_on(port: &str) -> Result<(TcpListener, SocketAddr)> { + let addr: SocketAddr = format!("0.0.0.0:{}", port).parse()?; + + let listener = TcpListener::bind(addr).await?; + let addr = listener.local_addr()?; + println!("Server listening on localhost:{}", addr.port()); + + Ok((listener, addr)) +} + +pub fn endpoint(ticket: &str, location_uri: impl Into) -> FlightEndpoint { + FlightEndpoint { + ticket: Some(Ticket { + ticket: ticket.as_bytes().to_vec(), + }), + location: vec![Location { + uri: location_uri.into(), + }], + } +} diff --git a/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs new file mode 100644 index 00000000000..355209f2efb --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::pin::Pin; +use std::sync::Arc; + +use arrow_flight::{ + flight_service_server::FlightService, flight_service_server::FlightServiceServer, + Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, + FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, +}; +use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use tokio::sync::Mutex; +use tonic::{ + metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming, +}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +type Error = Box; +type Result = std::result::Result; + +use prost::Message; + +use crate::{AUTH_PASSWORD, AUTH_USERNAME}; + +pub async fn scenario_setup(port: &str) -> Result { + let (mut listener, _) = super::listen_on(port).await?; + + let service = AuthBasicProtoScenarioImpl { + username: AUTH_USERNAME.into(), + password: AUTH_PASSWORD.into(), + peer_identity: Arc::new(Mutex::new(None)), + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + Ok(()) +} + +#[derive(Clone)] +pub struct AuthBasicProtoScenarioImpl { + username: Arc, + password: Arc, + peer_identity: Arc>>, +} + +impl AuthBasicProtoScenarioImpl { + async fn check_auth( + &self, + metadata: &MetadataMap, + ) -> Result { + let token = metadata + .get_bin("auth-token-bin") + .and_then(|v| v.to_bytes().ok()) + .and_then(|b| String::from_utf8(b.to_vec()).ok()); + self.is_valid(token).await + } + + async fn is_valid( + &self, + token: Option, + ) -> Result { + match token { + Some(t) if t == *self.username => Ok(GrpcServerCallContext { + peer_identity: self.username.to_string(), + }), + _ => Err(Status::unauthenticated("Invalid token")), + } + } +} + +struct GrpcServerCallContext { + peer_identity: String, +} + +impl GrpcServerCallContext { + pub fn peer_identity(&self) -> &str { + &self.peer_identity + } +} + +#[tonic::async_trait] +impl FlightService for AuthBasicProtoScenarioImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn handshake( + &self, + request: Request>, + ) -> Result, Status> { + let (tx, rx) = mpsc::channel(10); + + tokio::spawn({ + let username = self.username.clone(); + let password = self.password.clone(); + + async move { + let requests = request.into_inner(); + + requests + .for_each(move |req| { + let mut tx = tx.clone(); + let req = req.expect("Error reading handshake request"); + let HandshakeRequest { payload, .. 
} = req; + + let auth = BasicAuth::decode(&*payload) + .expect("Error parsing handshake request"); + + let resp = if *auth.username == *username + && *auth.password == *password + { + Ok(HandshakeResponse { + payload: username.as_bytes().to_vec(), + ..HandshakeResponse::default() + }) + } else { + Err(Status::unauthenticated(format!( + "Don't know user {}", + auth.username + ))) + }; + + async move { + tx.send(resp) + .await + .expect("Error sending handshake response"); + } + }) + .await; + } + }); + + Ok(Response::new(Box::pin(rx))) + } + + async fn list_flights( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_action( + &self, + request: Request, + ) -> Result, Status> { + let flight_context = self.check_auth(request.metadata()).await?; + // Respond with the authenticated username. + let buf = flight_context.peer_identity().as_bytes().to_vec(); + let result = arrow_flight::Result { body: buf }; + let output = futures::stream::once(async { Ok(result) }); + Ok(Response::new(Box::pin(output) as Self::DoActionStream)) + } + + async fn list_actions( + &self, + request: Request, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + request: Request>, + ) -> Result, Status> { + self.check_auth(request.metadata()).await?; + Err(Status::unimplemented("Not yet implemented")) + } +} diff --git a/rust/integration-testing/src/flight_server_scenarios/integration_test.rs b/rust/integration-testing/src/flight_server_scenarios/integration_test.rs new file mode 100644 index 00000000000..a555b2efad9 --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios/integration_test.rs @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::convert::TryFrom; +use std::pin::Pin; +use std::sync::Arc; + +use arrow::{ + array::ArrayRef, + datatypes::Schema, + datatypes::SchemaRef, + ipc::{self, reader}, + record_batch::RecordBatch, +}; +use arrow_flight::{ + flight_descriptor::DescriptorType, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, + FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, PutResult, SchemaResult, Ticket, +}; +use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use tokio::sync::Mutex; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +type TonicStream = Pin + Send + Sync + 'static>>; + +type Error = Box; +type Result = std::result::Result; + +pub async fn scenario_setup(port: &str) -> Result { + let (mut listener, addr) = super::listen_on(port).await?; + + let service = FlightServiceImpl { + server_location: format!("grpc+tcp://{}", addr), + ..Default::default() + }; + let svc = FlightServiceServer::new(service); + + Server::builder() + .add_service(svc) + .serve_with_incoming(listener.incoming()) + .await?; + + Ok(()) +} + +#[derive(Debug, Clone)] +struct IntegrationDataset { + schema: Schema, + chunks: Vec, +} + +#[derive(Clone, Default)] +pub struct FlightServiceImpl { + server_location: String, + uploaded_chunks: Arc>>, +} + +impl FlightServiceImpl { + fn endpoint_from_path(&self, path: &str) -> FlightEndpoint { + super::endpoint(path, &self.server_location) + } +} + +#[tonic::async_trait] +impl FlightService for FlightServiceImpl { + type HandshakeStream = TonicStream>; + type ListFlightsStream = TonicStream>; + type DoGetStream = TonicStream>; + type DoPutStream = TonicStream>; + type DoActionStream = TonicStream>; + type ListActionsStream = TonicStream>; + type DoExchangeStream = TonicStream>; + + async fn get_schema( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + let ticket = request.into_inner(); + + let key = String::from_utf8(ticket.ticket.to_vec()) + .map_err(|e| Status::invalid_argument(format!("Invalid ticket: {:?}", e)))?; + + let uploaded_chunks = self.uploaded_chunks.lock().await; + + let flight = uploaded_chunks.get(&key).ok_or_else(|| { + Status::not_found(format!("Could not find flight. 
{}", key)) + })?; + + let options = arrow::ipc::writer::IpcWriteOptions::default(); + + let schema = std::iter::once({ + Ok(arrow_flight::utils::flight_data_from_arrow_schema( + &flight.schema, + &options, + )) + }); + + let batches = flight + .chunks + .iter() + .enumerate() + .flat_map(|(counter, batch)| { + let (dictionary_flight_data, mut batch_flight_data) = + arrow_flight::utils::flight_data_from_arrow_batch(batch, &options); + + // Only the record batch's FlightData gets app_metadata + let metadata = counter.to_string().into_bytes(); + batch_flight_data.app_metadata = metadata; + + dictionary_flight_data + .into_iter() + .chain(std::iter::once(batch_flight_data)) + .map(Ok) + }); + + let output = futures::stream::iter(schema.chain(batches).collect::>()); + + Ok(Response::new(Box::pin(output) as Self::DoGetStream)) + } + + async fn handshake( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_flights( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn get_flight_info( + &self, + request: Request, + ) -> Result, Status> { + let descriptor = request.into_inner(); + + match descriptor.r#type { + t if t == DescriptorType::Path as i32 => { + let path = &descriptor.path; + if path.is_empty() { + return Err(Status::invalid_argument("Invalid path")); + } + + let uploaded_chunks = self.uploaded_chunks.lock().await; + let flight = uploaded_chunks.get(&path[0]).ok_or_else(|| { + Status::not_found(format!("Could not find flight. {}", path[0])) + })?; + + let endpoint = self.endpoint_from_path(&path[0]); + + let total_records: usize = + flight.chunks.iter().map(|chunk| chunk.num_rows()).sum(); + + let options = arrow::ipc::writer::IpcWriteOptions::default(); + let schema = arrow_flight::utils::ipc_message_from_arrow_schema( + &flight.schema, + &options, + ) + .expect( + "Could not generate schema bytes from schema stored by a DoPut; \ + this should be impossible", + ); + + let info = FlightInfo { + schema, + flight_descriptor: Some(descriptor.clone()), + endpoint: vec![endpoint], + total_records: total_records as i64, + total_bytes: -1, + }; + + Ok(Response::new(info)) + } + other => Err(Status::unimplemented(format!("Request type: {}", other))), + } + } + + async fn do_put( + &self, + request: Request>, + ) -> Result, Status> { + let mut input_stream = request.into_inner(); + let flight_data = input_stream + .message() + .await? 
+ .ok_or_else(|| Status::invalid_argument("Must send some FlightData"))?; + + let descriptor = flight_data + .flight_descriptor + .clone() + .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?; + + if descriptor.r#type != DescriptorType::Path as i32 || descriptor.path.is_empty() + { + return Err(Status::invalid_argument("Must specify a path")); + } + + let key = descriptor.path[0].clone(); + + let schema = Schema::try_from(&flight_data) + .map_err(|e| Status::invalid_argument(format!("Invalid schema: {:?}", e)))?; + let schema_ref = Arc::new(schema.clone()); + + let (response_tx, response_rx) = mpsc::channel(10); + + let uploaded_chunks = self.uploaded_chunks.clone(); + + tokio::spawn(async { + let mut error_tx = response_tx.clone(); + if let Err(e) = save_uploaded_chunks( + uploaded_chunks, + schema_ref, + input_stream, + response_tx, + schema, + key, + ) + .await + { + error_tx.send(Err(e)).await.expect("Error sending error") + } + }); + + Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream)) + } + + async fn do_action( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn list_actions( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } + + async fn do_exchange( + &self, + _request: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } +} + +async fn send_app_metadata( + tx: &mut mpsc::Sender>, + app_metadata: &[u8], +) -> Result<(), Status> { + tx.send(Ok(PutResult { + app_metadata: app_metadata.to_vec(), + })) + .await + .map_err(|e| Status::internal(format!("Could not send PutResult: {:?}", e))) +} + +async fn record_batch_from_message( + message: ipc::Message<'_>, + data_body: &[u8], + schema_ref: SchemaRef, + dictionaries_by_field: &[Option], +) -> Result { + let ipc_batch = message.header_as_record_batch().ok_or_else(|| { + Status::internal("Could not parse message header as record batch") + })?; + + let arrow_batch_result = reader::read_record_batch( + data_body, + ipc_batch, + schema_ref, + &dictionaries_by_field, + ); + + arrow_batch_result.map_err(|e| { + Status::internal(format!("Could not convert to RecordBatch: {:?}", e)) + }) +} + +async fn dictionary_from_message( + message: ipc::Message<'_>, + data_body: &[u8], + schema_ref: SchemaRef, + dictionaries_by_field: &mut [Option], +) -> Result<(), Status> { + let ipc_batch = message.header_as_dictionary_batch().ok_or_else(|| { + Status::internal("Could not parse message header as dictionary batch") + })?; + + let dictionary_batch_result = + reader::read_dictionary(data_body, ipc_batch, &schema_ref, dictionaries_by_field); + dictionary_batch_result.map_err(|e| { + Status::internal(format!("Could not convert to Dictionary: {:?}", e)) + }) +} + +async fn save_uploaded_chunks( + uploaded_chunks: Arc>>, + schema_ref: Arc, + mut input_stream: Streaming, + mut response_tx: mpsc::Sender>, + schema: Schema, + key: String, +) -> Result<(), Status> { + let mut chunks = vec![]; + let mut uploaded_chunks = uploaded_chunks.lock().await; + + let mut dictionaries_by_field = vec![None; schema_ref.fields().len()]; + + while let Some(Ok(data)) = input_stream.next().await { + let message = arrow::ipc::root_as_message(&data.data_header[..]) + .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?; + + match message.header_type() { + ipc::MessageHeader::Schema => { + return Err(Status::internal( + "Not expecting a schema when messages are 
read", + )) + } + ipc::MessageHeader::RecordBatch => { + send_app_metadata(&mut response_tx, &data.app_metadata).await?; + + let batch = record_batch_from_message( + message, + &data.data_body, + schema_ref.clone(), + &dictionaries_by_field, + ) + .await?; + + chunks.push(batch); + } + ipc::MessageHeader::DictionaryBatch => { + dictionary_from_message( + message, + &data.data_body, + schema_ref.clone(), + &mut dictionaries_by_field, + ) + .await?; + } + t => { + return Err(Status::internal(format!( + "Reading types other than record batches not yet supported, \ + unable to read {:?}", + t + ))); + } + } + } + + let dataset = IntegrationDataset { schema, chunks }; + uploaded_chunks.insert(key, dataset); + + Ok(()) +} diff --git a/rust/integration-testing/src/flight_server_scenarios/middleware.rs b/rust/integration-testing/src/flight_server_scenarios/middleware.rs new file mode 100644 index 00000000000..12421bc8928 --- /dev/null +++ b/rust/integration-testing/src/flight_server_scenarios/middleware.rs @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::pin::Pin;
+
+use arrow_flight::{
+    flight_descriptor::DescriptorType, flight_service_server::FlightService,
+    flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty,
+    FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
+    PutResult, SchemaResult, Ticket,
+};
+use futures::Stream;
+use tonic::{transport::Server, Request, Response, Status, Streaming};
+
+type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+pub async fn scenario_setup(port: &str) -> Result {
+    let (mut listener, _) = super::listen_on(port).await?;
+
+    let service = MiddlewareScenarioImpl {};
+    let svc = FlightServiceServer::new(service);
+
+    Server::builder()
+        .add_service(svc)
+        .serve_with_incoming(listener.incoming())
+        .await?;
+    Ok(())
+}
+
+#[derive(Clone, Default)]
+pub struct MiddlewareScenarioImpl {}
+
+#[tonic::async_trait]
+impl FlightService for MiddlewareScenarioImpl {
+    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
+    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
+    type DoGetStream = TonicStream<Result<FlightData, Status>>;
+    type DoPutStream = TonicStream<Result<PutResult, Status>>;
+    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
+    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
+    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
+
+    async fn get_schema(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn do_get(
+        &self,
+        _request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn handshake(
+        &self,
+        _request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn list_flights(
+        &self,
+        _request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn get_flight_info(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        let middleware_header = request.metadata().get("x-middleware").cloned();
+
+        let descriptor = request.into_inner();
+
+        if descriptor.r#type == DescriptorType::Cmd as i32 && descriptor.cmd == b"success"
+        {
+            // Return a fake location - the test doesn't read it
+            let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010");
+
+            let info = FlightInfo {
+                flight_descriptor: Some(descriptor),
+                endpoint: vec![endpoint],
+                ..Default::default()
+            };
+
+            let mut response = Response::new(info);
+            if let Some(value) = middleware_header {
+                response.metadata_mut().insert("x-middleware", value);
+            }
+
+            return Ok(response);
+        }
+
+        let mut status = Status::unknown("Unknown");
+        if let Some(value) = middleware_header {
+            status.metadata_mut().insert("x-middleware", value);
+        }
+
+        Err(status)
+    }
+
+    async fn do_put(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn do_action(
+        &self,
+        _request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn list_actions(
+        &self,
+        _request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    async fn do_exchange(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+}
diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs
index b93f1c4aa51..93bc8d55626 100644
--- a/rust/integration-testing/src/lib.rs
+++ b/rust/integration-testing/src/lib.rs
@@ -38,6 +38,14 @@ use std::fs::File;
 use std::io::BufReader;
 use std::sync::Arc;
 
+/// The expected username for the
basic auth integration test. +pub const AUTH_USERNAME: &str = "arrow"; +/// The expected password for the basic auth integration test. +pub const AUTH_PASSWORD: &str = "flight"; + +pub mod flight_client_scenarios; +pub mod flight_server_scenarios; + pub struct ArrowFile { pub schema: Schema, // we can evolve this into a concrete Arrow type