diff --git a/rust/benchmarks/README.md b/rust/benchmarks/README.md
index 9bff3e2d8ee..2ae035b9fc4 100644
--- a/rust/benchmarks/README.md
+++ b/rust/benchmarks/README.md
@@ -37,6 +37,7 @@ clone the repository and build the source code.
 git clone git@github.com:databricks/tpch-dbgen.git
 cd tpch-dbgen
 make
+export TPCH_DATA=$(pwd)
 ```
 
 Data can now be generated with the following command. Note that `-s 1` means use Scale Factor 1 or ~1 GB of
@@ -63,7 +64,7 @@ This utility does not yet provide support for changing the number of partitions
 option is to use the following Docker image to perform the conversion from `tbl` files to CSV or Parquet.
 
 ```bash
-docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT
+docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT <subcommand>
   -h, --help   Show help message
 
 Subcommand: convert-tpch
diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs
index 8a6fca78930..a13db9d6577 100644
--- a/rust/benchmarks/src/bin/tpch.rs
+++ b/rust/benchmarks/src/bin/tpch.rs
@@ -104,15 +104,14 @@ const TABLES: &[&str] = &[
 
 #[tokio::main]
 async fn main() -> Result<()> {
+    env_logger::init();
     match TpchOpt::from_args() {
-        TpchOpt::Benchmark(opt) => benchmark(opt).await,
+        TpchOpt::Benchmark(opt) => benchmark(opt).await.map(|_| ()),
         TpchOpt::Convert(opt) => convert_tbl(opt).await,
     }
 }
 
-async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
-    env_logger::init();
-
+async fn benchmark(opt: BenchmarkOpt) -> Result<Vec<arrow::record_batch::RecordBatch>> {
     println!("Running benchmarks with the following options: {:?}", opt);
     let config = ExecutionConfig::new()
         .with_concurrency(opt.concurrency)
@@ -142,10 +141,11 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
     let mut millis = vec![];
 
     // run benchmark
+    let mut result: Vec<arrow::record_batch::RecordBatch> = Vec::with_capacity(1);
     for i in 0..opt.iterations {
         let start = Instant::now();
         let plan = create_logical_plan(&mut ctx, opt.query)?;
-        execute_query(&mut ctx, &plan, opt.debug).await?;
+        result = execute_query(&mut ctx, &plan, opt.debug).await?;
         let elapsed = start.elapsed().as_secs_f64() * 1000.0;
         millis.push(elapsed as f64);
         println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed);
@@ -154,7 +154,7 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
     let avg = millis.iter().sum::<f64>() / millis.len() as f64;
     println!("Query {} avg time: {:.2} ms", opt.query, avg);
 
-    Ok(())
+    Ok(result)
 }
 
 fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<LogicalPlan> {
@@ -990,7 +990,7 @@ async fn execute_query(
     ctx: &mut ExecutionContext,
     plan: &LogicalPlan,
     debug: bool,
-) -> Result<()> {
+) -> Result<Vec<arrow::record_batch::RecordBatch>> {
     if debug {
         println!("Logical plan:\n{:?}", plan);
     }
@@ -1003,12 +1003,11 @@ async fn execute_query(
     if debug {
         pretty::print_batches(&result)?;
     }
-    Ok(())
+    Ok(result)
 }
 
 async fn convert_tbl(opt: ConvertOpt) -> Result<()> {
     let output_root_path = Path::new(&opt.output_path);
-
     for table in TABLES {
         let start = Instant::now();
         let schema = get_schema(table);
@@ -1083,13 +1082,14 @@ fn get_table(
     table_format: &str,
 ) -> Result<Box<dyn TableProvider + Send + Sync>> {
     match table_format {
-        // dbgen creates .tbl ('|' delimited) files
+        // dbgen creates .tbl ('|' delimited) files without header
         "tbl" => {
             let path = format!("{}/{}.tbl", path, table);
             let schema = get_schema(table);
             let options = CsvReadOptions::new()
                 .schema(&schema)
                 .delimiter(b'|')
+                .has_header(false)
                 .file_extension(".tbl");
 
             Ok(Box::new(CsvFile::try_new(&path, options)?))
@@ -1125,7 +1125,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("p_type", DataType::Utf8, false),
             Field::new("p_size", DataType::Int32, false),
Field::new("p_container", DataType::Utf8, false), - Field::new("p_retailprice", DataType::Float64, false), // decimal + Field::new("p_retailprice", DataType::Float64, false), Field::new("p_comment", DataType::Utf8, false), ]), @@ -1135,7 +1135,7 @@ fn get_schema(table: &str) -> Schema { Field::new("s_address", DataType::Utf8, false), Field::new("s_nationkey", DataType::Int32, false), Field::new("s_phone", DataType::Utf8, false), - Field::new("s_acctbal", DataType::Float64, false), // decimal + Field::new("s_acctbal", DataType::Float64, false), Field::new("s_comment", DataType::Utf8, false), ]), @@ -1143,7 +1143,7 @@ fn get_schema(table: &str) -> Schema { Field::new("ps_partkey", DataType::Int32, false), Field::new("ps_suppkey", DataType::Int32, false), Field::new("ps_availqty", DataType::Int32, false), - Field::new("ps_supplycost", DataType::Float64, false), // decimal + Field::new("ps_supplycost", DataType::Float64, false), Field::new("ps_comment", DataType::Utf8, false), ]), @@ -1153,7 +1153,7 @@ fn get_schema(table: &str) -> Schema { Field::new("c_address", DataType::Utf8, false), Field::new("c_nationkey", DataType::Int32, false), Field::new("c_phone", DataType::Utf8, false), - Field::new("c_acctbal", DataType::Float64, false), // decimal + Field::new("c_acctbal", DataType::Float64, false), Field::new("c_mktsegment", DataType::Utf8, false), Field::new("c_comment", DataType::Utf8, false), ]), @@ -1162,7 +1162,7 @@ fn get_schema(table: &str) -> Schema { Field::new("o_orderkey", DataType::Int32, false), Field::new("o_custkey", DataType::Int32, false), Field::new("o_orderstatus", DataType::Utf8, false), - Field::new("o_totalprice", DataType::Float64, false), // decimal + Field::new("o_totalprice", DataType::Float64, false), Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false), Field::new("o_orderpriority", DataType::Utf8, false), Field::new("o_clerk", DataType::Utf8, false), @@ -1175,10 +1175,10 @@ fn get_schema(table: &str) -> Schema { Field::new("l_partkey", DataType::Int32, false), Field::new("l_suppkey", DataType::Int32, false), Field::new("l_linenumber", DataType::Int32, false), - Field::new("l_quantity", DataType::Float64, false), // decimal - Field::new("l_extendedprice", DataType::Float64, false), // decimal - Field::new("l_discount", DataType::Float64, false), // decimal - Field::new("l_tax", DataType::Float64, false), // decimal + Field::new("l_quantity", DataType::Float64, false), + Field::new("l_extendedprice", DataType::Float64, false), + Field::new("l_discount", DataType::Float64, false), + Field::new("l_tax", DataType::Float64, false), Field::new("l_returnflag", DataType::Utf8, false), Field::new("l_linestatus", DataType::Utf8, false), Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false), @@ -1205,3 +1205,404 @@ fn get_schema(table: &str) -> Schema { _ => unimplemented!(), } } + +#[cfg(test)] +mod tests { + use super::*; + use std::env; + use std::sync::Arc; + + use arrow::array::*; + use arrow::record_batch::RecordBatch; + use arrow::util::display::array_value_to_string; + + use datafusion::logical_plan::Expr; + use datafusion::logical_plan::Expr::Cast; + + #[tokio::test] + async fn q1() -> Result<()> { + verify_query(1).await + } + + #[tokio::test] + async fn q2() -> Result<()> { + verify_query(2).await + } + + #[tokio::test] + async fn q3() -> Result<()> { + verify_query(3).await + } + + #[tokio::test] + async fn q4() -> Result<()> { + verify_query(4).await + } + + #[tokio::test] + async fn q5() -> Result<()> { + verify_query(5).await + } + + 
+    #[tokio::test]
+    async fn q1() -> Result<()> {
+        verify_query(1).await
+    }
+
+    #[tokio::test]
+    async fn q2() -> Result<()> {
+        verify_query(2).await
+    }
+
+    #[tokio::test]
+    async fn q3() -> Result<()> {
+        verify_query(3).await
+    }
+
+    #[tokio::test]
+    async fn q4() -> Result<()> {
+        verify_query(4).await
+    }
+
+    #[tokio::test]
+    async fn q5() -> Result<()> {
+        verify_query(5).await
+    }
+
+    #[tokio::test]
+    async fn q6() -> Result<()> {
+        verify_query(6).await
+    }
+
+    #[tokio::test]
+    async fn q7() -> Result<()> {
+        verify_query(7).await
+    }
+
+    #[tokio::test]
+    async fn q8() -> Result<()> {
+        verify_query(8).await
+    }
+
+    #[tokio::test]
+    async fn q9() -> Result<()> {
+        verify_query(9).await
+    }
+
+    #[tokio::test]
+    async fn q10() -> Result<()> {
+        verify_query(10).await
+    }
+
+    #[tokio::test]
+    async fn q11() -> Result<()> {
+        verify_query(11).await
+    }
+
+    #[tokio::test]
+    async fn q12() -> Result<()> {
+        verify_query(12).await
+    }
+
+    #[tokio::test]
+    async fn q13() -> Result<()> {
+        verify_query(13).await
+    }
+
+    #[tokio::test]
+    async fn q14() -> Result<()> {
+        verify_query(14).await
+    }
+
+    #[tokio::test]
+    async fn q15() -> Result<()> {
+        verify_query(15).await
+    }
+
+    #[tokio::test]
+    async fn q16() -> Result<()> {
+        verify_query(16).await
+    }
+
+    #[tokio::test]
+    async fn q17() -> Result<()> {
+        verify_query(17).await
+    }
+
+    #[tokio::test]
+    async fn q18() -> Result<()> {
+        verify_query(18).await
+    }
+
+    #[tokio::test]
+    async fn q19() -> Result<()> {
+        verify_query(19).await
+    }
+
+    #[tokio::test]
+    async fn q20() -> Result<()> {
+        verify_query(20).await
+    }
+
+    #[tokio::test]
+    async fn q21() -> Result<()> {
+        verify_query(21).await
+    }
+
+    #[tokio::test]
+    async fn q22() -> Result<()> {
+        verify_query(22).await
+    }
+
+    /// Specialised String representation
+    fn col_str(column: &ArrayRef, row_index: usize) -> String {
+        if column.is_null(row_index) {
+            return "NULL".to_string();
+        }
+
+        // Special case ListArray as there is no pretty print support for it yet
+        if let DataType::FixedSizeList(_, n) = column.data_type() {
+            let array = column
+                .as_any()
+                .downcast_ref::<FixedSizeListArray>()
+                .unwrap()
+                .value(row_index);
+
+            let mut r = Vec::with_capacity(*n as usize);
+            for i in 0..*n {
+                r.push(col_str(&array, i as usize));
+            }
+            return format!("[{}]", r.join(","));
+        }
+
+        array_value_to_string(column, row_index).unwrap()
+    }
+
+    /// Converts the results into a 2d array of strings, `result[row][column]`
+    /// Special cases nulls to NULL for testing
+    fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
+        let mut result = vec![];
+        for batch in results {
+            for row_index in 0..batch.num_rows() {
+                let row_vec = batch
+                    .columns()
+                    .iter()
+                    .map(|column| col_str(column, row_index))
+                    .collect();
+                result.push(row_vec);
+            }
+        }
+        result
+    }
+
Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true), + Field::new("o_shippriority", DataType::Int32, true), + ]), + + 4 => Schema::new(vec![ + Field::new("o_orderpriority", DataType::Utf8, true), + Field::new("order_count", DataType::Int32, true), + ]), + + 5 => Schema::new(vec![ + Field::new("n_name", DataType::Utf8, true), + Field::new("revenue", DataType::Float64, true), + ]), + + 6 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + + 7 => Schema::new(vec![ + Field::new("supp_nation", DataType::Utf8, true), + Field::new("cust_nation", DataType::Utf8, true), + Field::new("l_year", DataType::Int32, true), + Field::new("revenue", DataType::Float64, true), + ]), + + 8 => Schema::new(vec![ + Field::new("o_year", DataType::Int32, true), + Field::new("mkt_share", DataType::Float64, true), + ]), + + 9 => Schema::new(vec![ + Field::new("nation", DataType::Utf8, true), + Field::new("o_year", DataType::Int32, true), + Field::new("sum_profit", DataType::Float64, true), + ]), + + 10 => Schema::new(vec![ + Field::new("c_custkey", DataType::Int32, true), + Field::new("c_name", DataType::Utf8, true), + Field::new("revenue", DataType::Float64, true), + Field::new("c_acctbal", DataType::Float64, true), + Field::new("n_name", DataType::Utf8, true), + Field::new("c_address", DataType::Utf8, true), + Field::new("c_phone", DataType::Utf8, true), + Field::new("c_comment", DataType::Utf8, true), + ]), + + 11 => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int32, true), + Field::new("value", DataType::Float64, true), + ]), + + 12 => Schema::new(vec![ + Field::new("l_shipmode", DataType::Utf8, true), + Field::new("high_line_count", DataType::Int64, true), + Field::new("low_line_count", DataType::Int64, true), + ]), + + 13 => Schema::new(vec![ + Field::new("c_count", DataType::Int64, true), + Field::new("custdist", DataType::Int64, true), + ]), + + 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + + 15 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]), + + 16 => Schema::new(vec![ + Field::new("p_brand", DataType::Utf8, true), + Field::new("p_type", DataType::Utf8, true), + Field::new("c_phone", DataType::Int32, true), + Field::new("c_comment", DataType::Int32, true), + ]), + + 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]), + + 18 => Schema::new(vec![ + Field::new("c_name", DataType::Utf8, true), + Field::new("c_custkey", DataType::Int32, true), + Field::new("o_orderkey", DataType::Int32, true), + Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true), + Field::new("o_totalprice", DataType::Float64, true), + Field::new("sum_l_quantity", DataType::Float64, true), + ]), + + 19 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]), + + 20 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("s_address", DataType::Utf8, true), + ]), + + 21 => Schema::new(vec![ + Field::new("s_name", DataType::Utf8, true), + Field::new("numwait", DataType::Int32, true), + ]), + + 22 => Schema::new(vec![ + Field::new("cntrycode", DataType::Int32, true), + Field::new("numcust", DataType::Int32, true), + Field::new("totacctbal", DataType::Float64, true), + ]), + + _ => unimplemented!(), + } + } + + // convert expected schema to all utf8 so columns can be read as strings to be parsed separately + // this is due to the fact that the csv parser cannot handle leading/trailing spaces + fn string_schema(schema: Schema) -> Schema { + Schema::new( + schema + .fields() + 
+    async fn verify_query(n: usize) -> Result<()> {
+        if let Ok(path) = env::var("TPCH_DATA") {
+            // load expected answers from tpch-dbgen
+            // read csv as all strings, trim and cast to expected type as the csv string
+            // to value parser does not handle data with leading/trailing spaces
+            let mut ctx = ExecutionContext::new();
+            let schema = string_schema(get_answer_schema(n));
+            let options = CsvReadOptions::new()
+                .schema(&schema)
+                .delimiter(b'|')
+                .file_extension(".out");
+            let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?;
+            let df = df.select(
+                get_answer_schema(n)
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Expr::Alias(
+                            Box::new(Cast {
+                                expr: Box::new(trim(col(Field::name(&field)))),
+                                data_type: Field::data_type(&field).to_owned(),
+                            }),
+                            Field::name(&field).to_string(),
+                        )
+                    })
+                    .collect::<Vec<Expr>>(),
+            )?;
+            let expected = df.collect().await?;
+
+            // run the query to compute actual results of the query
+            let opt = BenchmarkOpt {
+                query: n,
+                debug: false,
+                iterations: 1,
+                concurrency: 2,
+                batch_size: 4096,
+                path: PathBuf::from(path.to_string()),
+                file_format: "tbl".to_string(),
+                mem_table: false,
+            };
+            let actual = benchmark(opt).await?;
+
+            // assert schema equality without comparing nullable values
+            assert_eq!(
+                nullable_schema(expected[0].schema()),
+                nullable_schema(actual[0].schema())
+            );
+
+            // convert both datasets to Vec<Vec<String>> for simple comparison
+            let expected_vec = result_vec(&expected);
+            let actual_vec = result_vec(&actual);
+
+            // basic result comparison
+            assert_eq!(expected_vec.len(), actual_vec.len());
+
+            // compare each row. this works as all TPC-H queries have deterministically ordered results
+            for i in 0..actual_vec.len() {
+                assert_eq!(expected_vec[i], actual_vec[i]);
+            }
+        } else {
+            println!("TPCH_DATA environment variable not set, skipping test");
+        }
+
+        Ok(())
+    }
+}
diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs
index 309b75bc6b1..c8a4804ae6c 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/datafusion/src/prelude.rs
@@ -28,7 +28,7 @@ pub use crate::dataframe::DataFrame;
 pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
 
 pub use crate::logical_plan::{
-    array, avg, col, concat, count, create_udf, length, lit, max, min, sum, JoinType,
-    Partitioning,
+    array, avg, col, concat, count, create_udf, length, lit, lower, max, min, sum, trim,
+    upper, JoinType, Partitioning,
 };
 pub use crate::physical_plan::csv::CsvReadOptions;
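
For reference, a minimal sketch of how the new verification tests could be run locally, assuming `tpch-dbgen` has been built as described in the README (the path below is illustrative):

```bash
# Point the tests at the dbgen output; verify_query expects the expected
# answers under $TPCH_DATA/answers/ and skips all checks when TPCH_DATA is unset.
export TPCH_DATA=/path/to/tpch-dbgen
cd rust/benchmarks
cargo test
```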