From 3cf18fbd7e4f11a2362ac425813eb74240b68b52 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Thu, 20 Jan 2022 13:26:46 +0800
Subject: [PATCH 01/38] Add serialization/deserialization for UDAFs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/client/Cargo.toml               |  8 +++----
 ballista/rust/core/Cargo.toml                 |  3 ++-
 ballista/rust/core/proto/ballista.proto       | 12 ++++++++++
 .../core/src/serde/logical_plan/from_proto.rs | 23 +++++++++++++++++++
 .../core/src/serde/logical_plan/to_proto.rs   | 21 ++++++++++++++++-
 ballista/rust/executor/Cargo.toml             |  4 ++--
 ballista/rust/scheduler/Cargo.toml            |  6 ++---
 7 files changed, 66 insertions(+), 11 deletions(-)

diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml
index 7736e949d29f5..a9471ccb0aad3 100644
--- a/ballista/rust/client/Cargo.toml
+++ b/ballista/rust/client/Cargo.toml
@@ -27,14 +27,14 @@ edition = "2021"
 rust-version = "1.57"
 
 [dependencies]
-ballista-core = { path = "../core", version = "0.6.0" }
-ballista-executor = { path = "../executor", version = "0.6.0", optional = true }
-ballista-scheduler = { path = "../scheduler", version = "0.6.0", optional = true }
+ballista-core = { path = "../core"}
+ballista-executor = { path = "../executor", optional = true }
+ballista-scheduler = { path = "../scheduler", optional = true }
 futures = "0.3"
 log = "0.4"
 tokio = "1.0"
 
-datafusion = { path = "../../../datafusion", version = "6.0.0" }
+datafusion = { path = "../../../datafusion"}
 
 [features]
 default = []

diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index bbf8e274c5cd8..fc1a550ed902c 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -45,7 +45,8 @@ chrono = { version = "0.4", default-features = false }
 
 arrow-flight = { version = "7.0.0" }
 
-datafusion = { path = "../../../datafusion", version = "6.0.0" }
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
+argo-engine-common = { git = "http://git.analysysdata.com/noah/argo_engine.git", branch="master", package = "common" }
 
 [dev-dependencies]
 tempfile = "3"

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index cc6e00aa939f1..20049fab29c58 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -80,6 +80,10 @@ message LogicalExprNode {
 
     // window expressions
     WindowExprNode window_expr = 18;
+
+    //ArgoEngineAggregateUDF expressions
+    AggregateUDFExprNode aggregate_udf_expr = 19;
+
   }
 }
 
@@ -976,6 +980,14 @@ service SchedulerGrpc {
   rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {}
 }
 
+///////////////////////////////////////////////////////////////////////////////////////////////////
+// ArgoEngine add.
+///////////////////////////////////////////////////////////////////////////////////////////////////
+message AggregateUDFExprNode {
+  string fun_name = 1;
+  repeated LogicalExprNode args = 2;
+}
+
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 // Arrow Data Types
 ///////////////////////////////////////////////////////////////////////////////////////////////////

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index dfac547d7bb35..406b737d48837 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -17,6 +17,7 @@
 
 //! Serde code to convert from protocol buffers to Rust data structures.
 
+use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf;
 use crate::error::BallistaError;
 use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte};
 use crate::{convert_box_required, convert_required};
@@ -966,6 +967,26 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                     distinct: false, //TODO
                 })
             }
+            // argo engine add start
+            ExprType::AggregateUdfExpr(expr) => {
+                let fun = from_name_to_udaf(expr.fun_name.as_str()).map_err(|e| {
+                    proto_error(format!(
+                        "from_proto error: {}",
+                        e
+                    ))
+                })?;
+                let fun_arc = Arc::new(fun);
+                let fun_args= &expr.args;
+                let args: Vec<Expr> = fun_args
+                    .iter()
+                    .map(|e| e.try_into())
+                    .collect::<Result<Vec<_>, BallistaError>>()?;
+                Ok(Expr::AggregateUDF {
+                    fun: fun_arc,
+                    args: args.try_into().unwrap(),
+                })
+            }
+            // argo engine add end
             ExprType::Alias(alias) => Ok(Expr::Alias(
                 Box::new(parse_required_expr(&alias.expr)?),
                 alias.alias.clone(),
@@ -1174,6 +1195,8 @@ use datafusion::prelude::{
     sha384, sha512, trim, upper,
 };
 use std::convert::TryFrom;
+use futures::TryFutureExt;
+use datafusion::physical_plan::udaf::AggregateUDF;
 
 impl TryFrom<i32> for protobuf::FileType {
     type Error = BallistaError;

diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 36e9ba69ed5ad..39f9afe737708 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -54,6 +54,8 @@ use std::{
     boxed,
     convert::{TryFrom, TryInto},
 };
+use std::sync::Arc;
+use datafusion::physical_plan::udaf::AggregateUDF;
 
 impl protobuf::IntervalUnit {
     pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
@@ -1079,7 +1081,24 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
                 })
             }
             Expr::ScalarUDF { .. } => unimplemented!(),
-            Expr::AggregateUDF { .. } => unimplemented!(),
+            // argo engine add start
+            Expr::AggregateUDF { ref fun, ref args } => {
+                let args: Vec<protobuf::LogicalExprNode> = args
+                    .iter()
+                    .map(|e| e.try_into())
+                    .collect::<Result<Vec<_>, BallistaError>>()?;
+                Ok(protobuf::LogicalExprNode {
+                    expr_type: Some(
+                        protobuf::logical_expr_node::ExprType::AggregateUdfExpr(
+                            protobuf::AggregateUdfExprNode {
+                                fun_name: fun.name.clone(),
+                                args,
+                            },
+                        ),
+                    ),
+                })
+            }
+            // argo engine add end
             Expr::Not(expr) => {
                 let expr = Box::new(protobuf::Not {
                     expr: Some(Box::new(expr.as_ref().try_into()?)),

diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index c01bb20681dbd..d8c2aafd2b518 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -33,9 +33,9 @@ arrow = { version = "7.0.0" }
 arrow-flight = { version = "7.0.0" }
 anyhow = "1"
 async-trait = "0.1.36"
-ballista-core = { path = "../core", version = "0.6.0" }
+ballista-core = { path = "../core"}
 configure_me = "0.4.0"
-datafusion = { path = "../../../datafusion", version = "6.0.0" }
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
 env_logger = "0.9"
 futures = "0.3"
 log = "0.4"

diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml
index 0bacccf031d8c..3cae3f7fcad0a 100644
--- a/ballista/rust/scheduler/Cargo.toml
+++ b/ballista/rust/scheduler/Cargo.toml
@@ -32,10 +32,10 @@ sled = ["sled_package", "tokio-stream"]
 
 [dependencies]
 anyhow = "1"
-ballista-core = { path = "../core", version = "0.6.0" }
+ballista-core = { path = "../core"}
 clap = "2"
 configure_me = "0.4.0"
-datafusion = { path = "../../../datafusion", version = "6.0.0" }
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
 env_logger = "0.9"
 etcd-client = { version = "0.7", optional = true }
 futures = "0.3"
@@ -55,7 +55,7 @@ tower = { version = "0.4" }
 warp = "0.3"
 
 [dev-dependencies]
-ballista-core = { path = "../core", version = "0.6.0" }
+ballista-core = { path = "../core"}
 uuid = { version = "0.8", features = ["v4"] }
 
 [build-dependencies]

From 37f14173258957366c808346fee3aa644379799a Mon Sep 17 00:00:00 2001
From: gaojun
Date: Fri, 21 Jan 2022 10:40:24 +0800
Subject: [PATCH 02/38] Add serialization/deserialization for UDAF physical
 execution plans
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/proto/ballista.proto       | 11 ++++
 .../src/serde/physical_plan/from_proto.rs     | 23 +++++++
 .../core/src/serde/physical_plan/to_proto.rs  | 63 ++++++++++++-------
 3 files changed, 73 insertions(+), 24 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 20049fab29c58..b94bff4bed111 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -521,6 +521,10 @@ message PhysicalExprNode {
 
     // window expressions
     PhysicalWindowExprNode window_expr = 15;
+
+    // argo engine add.
+    PhysicalAggregateUDFExprNode aggregate_udf_expr = 16;
+    // argo engine add end.
   }
 }
 
@@ -529,6 +533,13 @@ message PhysicalAggregateExprNode {
   PhysicalExprNode expr = 2;
 }
 
+// argo engine add.
+message PhysicalAggregateUDFExprNode {
+  string fun_name = 1;
+  PhysicalExprNode expr = 2;
+}
+// argo engine add end.
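Both `AggregateUDFExprNode` and the `PhysicalAggregateUDFExprNode` just defined serialize a UDAF as nothing more than a function name plus argument expressions. That round-trips only if the scheduler and every executor resolve the same name to the same implementation, which is the job of `from_name_to_udaf` in the hunks that follow. The `argo_engine_common` crate is not part of this series, so the sketch below is only a guess at the shape of such a name-based registry; the type and function names here are assumptions, not the real API:

```rust
use std::collections::HashMap;

// Stand-in for datafusion::physical_plan::udaf::AggregateUDF; the real type
// also carries the accumulator factory and state/return-type functions.
#[derive(Clone)]
pub struct AggregateUdf {
    pub name: String,
}

// Hypothetical registry: every process must link the same table, because the
// wire format carries only `fun_name`.
pub struct UdafRegistry {
    udafs: HashMap<String, AggregateUdf>,
}

impl UdafRegistry {
    pub fn from_name_to_udaf(&self, name: &str) -> Result<AggregateUdf, String> {
        self.udafs
            .get(name)
            .cloned()
            .ok_or_else(|| format!("UDAF not registered: {}", name))
    }
}
```

If the registries ever diverge between scheduler and executors, deserialization fails at plan-distribution time through the `proto_error` path seen in `from_proto.rs` above.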
+
 
 message PhysicalWindowExprNode {
   oneof window_function {
     AggregateFunction aggr_function = 1;

diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index cad27b3156450..854701a2cbbba 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -17,6 +17,7 @@
 
 //! Serde code to convert from protocol buffers to Rust data structures.
 
+use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf;
 use std::collections::HashMap;
 use std::convert::{TryFrom, TryInto};
 use std::sync::Arc;
@@ -55,6 +56,7 @@ use datafusion::physical_plan::hash_join::PartitionMode;
 use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
 use datafusion::physical_plan::planner::DefaultPhysicalPlanner;
 use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions};
+use datafusion::physical_plan::udaf::create_aggregate_expr as create_aggregate_udf_expr;
 use datafusion::physical_plan::window_functions::{
     BuiltInWindowFunction, WindowFunction,
 };
@@ -311,6 +313,21 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                                     name.to_string(),
                                 )?)
                             }
+                            ExprType::AggregateUdfExpr(agg_node) => {
+                                let name = agg_node.fun_name.as_str();
+                                let fun = from_name_to_udaf(name).map_err(|e| {
+                                    proto_error(format!(
+                                        "from_proto error: {}",
+                                        e
+                                    ))
+                                })?;
+                                Ok(create_aggregate_udf_expr(
+                                    &fun,
+                                    &[convert_box_required!(agg_node.expr)?],
+                                    &physical_schema,
+                                    name.to_string(),
+                                )?)
+                            }
                             _ => Err(BallistaError::General(
                                 "Invalid aggregate expression for HashAggregateExec"
                                     .to_string(),
@@ -545,6 +562,12 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
                         .to_owned(),
                 ));
             }
+            ExprType::AggregateUdfExpr(_) => {
+                return Err(BallistaError::General(
+                    "Cannot convert aggregate udf expr node to physical expression"
+                        .to_owned(),
+                ));
+            }
            ExprType::WindowExpr(_) => {
                 return Err(BallistaError::General(
                     "Cannot convert window expr node to physical expression".to_owned(),

diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 930f0757e2020..78a50d78bce88 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -74,6 +74,7 @@ use crate::{
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::functions::{BuiltinScalarFunction, ScalarFunctionExpr};
 use datafusion::physical_plan::repartition::RepartitionExec;
+use datafusion::physical_plan::udaf::AggregateFunctionExpr;
 
 impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
     type Error = BallistaError;
 
     fn try_into(self) -> Result<protobuf::PhysicalExprNode, Self::Error> {
-        let aggr_function = if self.as_any().downcast_ref::<Avg>().is_some() {
-            Ok(protobuf::AggregateFunction::Avg.into())
-        } else if self.as_any().downcast_ref::<Sum>().is_some() {
-            Ok(protobuf::AggregateFunction::Sum.into())
-        } else if self.as_any().downcast_ref::<Count>().is_some() {
-            Ok(protobuf::AggregateFunction::Count.into())
-        } else if self.as_any().downcast_ref::<Min>().is_some() {
-            Ok(protobuf::AggregateFunction::Min.into())
-        } else if self.as_any().downcast_ref::<Max>().is_some() {
-            Ok(protobuf::AggregateFunction::Max.into())
-        } else {
-            Err(BallistaError::NotImplemented(format!(
-                "Aggregate function not supported: {:?}",
-                self
-            )))
-        }?;
+        // argo engine add.
+        // aggregate udf
         let expressions: Vec<protobuf::PhysicalExprNode> = self
             .expressions()
             .iter()
             .map(|e| e.clone().try_into())
             .collect::<Result<Vec<_>, BallistaError>>()?;
-        Ok(protobuf::PhysicalExprNode {
-            expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
-                Box::new(protobuf::PhysicalAggregateExprNode {
-                    aggr_function,
-                    expr: Some(Box::new(expressions[0].clone())),
-                }),
-            )),
-        })
+        if self.as_any().downcast_ref::<AggregateFunctionExpr>().is_some() {
+            let name = self.name();
+            Ok(protobuf::PhysicalExprNode {
+                expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateUdfExpr(
+                    Box::new(protobuf::PhysicalAggregateUdfExprNode {
+                        fun_name: name.to_string(),
+                        expr: Some(Box::new(expressions[0].clone())),
+                    }),
+                )),
+            })
+        } else {
+            let aggr_function = if self.as_any().downcast_ref::<Avg>().is_some() {
+                Ok(protobuf::AggregateFunction::Avg.into())
+            } else if self.as_any().downcast_ref::<Sum>().is_some() {
+                Ok(protobuf::AggregateFunction::Sum.into())
+            } else if self.as_any().downcast_ref::<Count>().is_some() {
+                Ok(protobuf::AggregateFunction::Count.into())
+            } else if self.as_any().downcast_ref::<Min>().is_some() {
+                Ok(protobuf::AggregateFunction::Min.into())
+            } else if self.as_any().downcast_ref::<Max>().is_some() {
+                Ok(protobuf::AggregateFunction::Max.into())
+            } else {
+                Err(BallistaError::NotImplemented(format!(
+                    "Aggregate function not supported: {:?}",
+                    self
+                )))
+            }?;
+            Ok(protobuf::PhysicalExprNode {
+                expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
+                    Box::new(protobuf::PhysicalAggregateExprNode {
+                        aggr_function,
+                        expr: Some(Box::new(expressions[0].clone())),
+                    }),
+                )),
+            })
+        }
     }
 }

From 350bee14ed27e48b992fefd3a309fb914b3851ed Mon Sep 17 00:00:00 2001
From: gaojun
Date: Fri, 21 Jan 2022 16:24:13 +0800
Subject: [PATCH 03/38] Fix downstream tasks hanging forever after an
 upstream task fails
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista-examples/Cargo.toml             |  2 +-
 ballista/rust/client/Cargo.toml          |  2 +-
 ballista/rust/scheduler/src/state/mod.rs | 95 +++++++++++++-----------
 benchmarks/Cargo.toml                    |  2 +-
 datafusion-cli/Cargo.toml                |  2 +-
 5 files changed, 57 insertions(+), 46 deletions(-)

diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml
index 338f69994bfd9..0461ccaae8f6f 100644
--- a/ballista-examples/Cargo.toml
+++ b/ballista-examples/Cargo.toml
@@ -29,7 +29,7 @@ publish = false
 rust-version = "1.57"
 
 [dependencies]
-datafusion = { path = "../datafusion" }
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
 ballista = { path = "../ballista/rust/client", version = "0.6.0"}
 prost = "0.9"
 tonic = "0.6"

diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml
index a9471ccb0aad3..a4dabc15f5d01 100644
--- a/ballista/rust/client/Cargo.toml
+++ b/ballista/rust/client/Cargo.toml
@@ -34,7 +34,7 @@ futures = "0.3"
 log = "0.4"
 tokio = "1.0"
 
-datafusion = { path = "../../../datafusion"}
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
 
 [features]
 default = []

diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs
index ef6de83127027..483d4e513a9aa 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -320,34 +320,36 @@ impl SchedulerState {
                         .await?;
                     if task_is_dead {
                         continue 'tasks;
-                    } else if let Some(task_status::Status::Completed(
-                        CompletedTask {
-                            executor_id,
-                            partitions,
-                        },
-                    )) = &referenced_task.status
-                    {
-                        debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}",
-                            shuffle_input_partition_id,
-                            partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::<Vec<_>>().join("\n\t")
-                        );
-                        let stage_shuffle_partition_locations = partition_locations
-                            .entry(unresolved_shuffle.stage_id)
-                            .or_insert_with(HashMap::new);
-                        let executor_meta = executors
-                            .iter()
-                            .find(|exec| exec.id == *executor_id)
-                            .unwrap()
-                            .clone();
-
-                        for shuffle_write_partition in partitions {
-                            let temp = stage_shuffle_partition_locations
-                                .entry(shuffle_write_partition.partition_id as usize)
-                                .or_insert_with(Vec::new);
-                            let executor_meta = executor_meta.clone();
-                            let partition_location =
-                                ballista_core::serde::scheduler::PartitionLocation {
-                                    partition_id:
+                    }
+
+                    match &referenced_task.status {
+                        Some(task_status::Status::Completed(
+                            CompletedTask {
+                                executor_id,
+                                partitions,
+                            },
+                        )) => {
+                            debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}",
+                                shuffle_input_partition_id,
+                                partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::<Vec<_>>().join("\n\t")
+                            );
+                            let stage_shuffle_partition_locations = partition_locations
+                                .entry(unresolved_shuffle.stage_id)
+                                .or_insert_with(HashMap::new);
+                            let executor_meta = executors
+                                .iter()
+                                .find(|exec| exec.id == *executor_id)
+                                .unwrap()
+                                .clone();
+
+                            for shuffle_write_partition in partitions {
+                                let temp = stage_shuffle_partition_locations
+                                    .entry(shuffle_write_partition.partition_id as usize)
+                                    .or_insert_with(Vec::new);
+                                let executor_meta = executor_meta.clone();
+                                let partition_location =
+                                    ballista_core::serde::scheduler::PartitionLocation {
+                                        partition_id:
                                             ballista_core::serde::scheduler::PartitionId {
                                                 job_id: partition.job_id.clone(),
                                                 stage_id: unresolved_shuffle.stage_id,
                                                 partition_id: shuffle_write_partition
                                                     .partition_id as usize,
                                             },
-                                    executor_meta,
-                                    partition_stats: PartitionStats::new(
-                                        Some(shuffle_write_partition.num_rows),
-                                        Some(shuffle_write_partition.num_batches),
-                                        Some(shuffle_write_partition.num_bytes),
-                                    ),
-                                    path: shuffle_write_partition.path.clone(),
-                                };
-                            debug!(
-                "Scheduler storing stage {} output partition {} path: {}",
-                unresolved_shuffle.stage_id,
-                partition_location.partition_id.partition_id,
-                partition_location.path
-            );
-                            temp.push(partition_location);
+                                        executor_meta,
+                                        partition_stats: PartitionStats::new(
+                                            Some(shuffle_write_partition.num_rows),
+                                            Some(shuffle_write_partition.num_batches),
+                                            Some(shuffle_write_partition.num_bytes),
+                                        ),
+                                        path: shuffle_write_partition.path.clone(),
+                                    };
+                                debug!(
+                "Scheduler storing stage {} output partition {} path: {}",
+                unresolved_shuffle.stage_id,
+                partition_location.partition_id.partition_id,
+                partition_location.path
+            );
+                                temp.push(partition_location);
+                            }
                         }
-                    } else {
-                        debug!(
-                "Stage {} input partition {} has not completed yet",
-                unresolved_shuffle.stage_id, shuffle_input_partition_id,
-            );
+                        Some(task_status::Status::Failed(FailedTask { error })) => {
+                            // A task should fail when its referenced_task fails
+                            let mut status = status.clone();
+                            let err_msg = format!("{}", error);
+                            status.status = Some(task_status::Status::Failed(FailedTask { error: err_msg}));
+                            self.save_task_status(&status).await?;
+                        }
+                        _ => {
+                            debug!(
+                "Stage {} input partition {} has not completed yet",
+                unresolved_shuffle.stage_id, shuffle_input_partition_id,
+            );
-                        continue 'tasks;
-                    }
+                            continue 'tasks;
+                        }
+                    };
                 }
             }

diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml
index d20de3106bd32..4b68412d17a13 100644
--- a/benchmarks/Cargo.toml
+++ b/benchmarks/Cargo.toml
@@ -32,7 +32,7 @@ simd = ["datafusion/simd"]
 snmalloc = ["snmalloc-rs"]
 
 [dependencies]
-datafusion = { path = "../datafusion" }
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
 ballista = { path = "../ballista/rust/client" }
 structopt = { version = "0.3", default-features = false }
 tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread"] }

diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index d5347d8e0009e..93fa5838c0e9d 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -30,6 +30,6 @@ rust-version = "1.57"
 clap = "2.33"
 rustyline = "9.0"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
-datafusion = { path = "../datafusion", version = "6.0.0" }
+datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" }
 arrow = { version = "7.0.0" }
 ballista = { path = "../ballista/rust/client", version = "0.6.0" }

From fffc70e6a5678e394f41ab2cb8b4a1055ee618d5 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Fri, 21 Jan 2022 16:53:48 +0800
Subject: [PATCH 04/38] Fix downstream tasks hanging forever after an
 upstream task fails
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/scheduler/src/state/mod.rs | 43 ++++++++++++++----------
 1 file changed, 25 insertions(+), 18 deletions(-)

diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs
index 483d4e513a9aa..535ced9d85df4 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -323,19 +323,18 @@ impl SchedulerState {
                     }
 
                     match &referenced_task.status {
-                        Some(task_status::Status::Completed(
-                            CompletedTask {
-                                executor_id,
-                                partitions,
-                            },
-                        )) => {
+                        Some(task_status::Status::Completed(CompletedTask {
+                            executor_id,
+                            partitions,
+                        })) => {
                             debug!("Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}",
                                 shuffle_input_partition_id,
                                 partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::<Vec<_>>().join("\n\t")
                             );
-                            let stage_shuffle_partition_locations = partition_locations
-                                .entry(unresolved_shuffle.stage_id)
-                                .or_insert_with(HashMap::new);
+                            let stage_shuffle_partition_locations =
+                                partition_locations
+                                    .entry(unresolved_shuffle.stage_id)
+                                    .or_insert_with(HashMap::new);
                             let executor_meta = executors
                                 .iter()
                                 .find(|exec| exec.id == *executor_id)
@@ -344,7 +343,9 @@ impl SchedulerState {
 
                             for shuffle_write_partition in partitions {
                                 let temp = stage_shuffle_partition_locations
-                                    .entry(shuffle_write_partition.partition_id as usize)
+                                    .entry(
+                                        shuffle_write_partition.partition_id as usize,
+                                    )
                                     .or_insert_with(Vec::new);
                                 let executor_meta = executor_meta.clone();
                                 let partition_location =
@@ -365,11 +366,12 @@ impl SchedulerState {
                                         ),
                                         path: shuffle_write_partition.path.clone(),
                                     };
+
                                 debug!(
-                "Scheduler storing stage {} output partition {} path: {}",
-                unresolved_shuffle.stage_id,
-                partition_location.partition_id.partition_id,
-                partition_location.path
+                                    "Scheduler storing stage {} output partition {} path: {}",
+                                    unresolved_shuffle.stage_id,
+                                    partition_location.partition_id.partition_id,
+                                    partition_location.path
                                 );
                                 temp.push(partition_location);
                             }
@@ -378,14 +380,19 @@ impl SchedulerState {
                             // A task should fail when its referenced_task fails
                             let mut status = status.clone();
                            let err_msg = format!("{}", error);
-                            status.status = Some(task_status::Status::Failed(FailedTask { error: err_msg}));
+                            status.status =
+                                Some(task_status::Status::Failed(FailedTask {
+                                    error: err_msg,
+                                }));
                             self.save_task_status(&status).await?;
+                            continue 'tasks;
                         }
                         _ => {
                             debug!(
-                "Stage {} input partition {} has not completed yet",
-                unresolved_shuffle.stage_id, shuffle_input_partition_id,
-            );
+                                "Stage {} input partition {} has not completed yet",
+                                unresolved_shuffle.stage_id,
+                                shuffle_input_partition_id,
+                            );
                             continue 'tasks;
                         }
                     };

From b2f2a6e89140d8048209e3c22e955d571aa25d29 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Fri, 21 Jan 2022 22:35:20 +0800
Subject: [PATCH 05/38] Point the submodule URLs at the Analysys GitLab to
 fix slow datafusion updates for argoengine
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .gitmodules | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index ec5d6208b8ddb..ad5ffe40c05f1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
 [submodule "parquet-testing"]
   path = parquet-testing
-  url = https://github.com/apache/parquet-testing.git
+  url = http://git.analysysdata.com/noah/parquet-testing.git
 [submodule "testing"]
   path = testing
-  url = https://github.com/apache/arrow-testing
+  url = http://git.analysysdata.com/noah/arrow-testing.git

From 264638ba1b15980dcfd5a40344411e146eef6563 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Sun, 23 Jan 2022 13:52:32 +0800
Subject: [PATCH 06/38] Add serialization/deserialization for Decimal and
 other data types
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/proto/ballista.proto       |  20 ++++
 .../core/src/serde/logical_plan/from_proto.rs | 106 ++++++++++++++++--
 .../core/src/serde/logical_plan/to_proto.rs   |  52 ++++++++-
 ballista/rust/core/src/serde/mod.rs           |  10 +-
 4 files changed, 174 insertions(+), 14 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index b94bff4bed111..d042805d9aba2 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -1108,9 +1108,22 @@ message ScalarValue{
         ScalarType null_list_value = 18;
 
         PrimitiveScalarType null_value = 19;
+
+        Decimal128 decimal128_value = 20;
+        int64 date_64_value = 21;
+        int64 time_second_value = 22;
+        int64 time_millisecond_value = 23;
+        int32 interval_yearmonth_value = 24;
+        int64 interval_daytime_value = 25;
+
     }
 }
 
+message Decimal128{
+  string value = 1;
+  int64 p = 2;
+  int64 s = 3;
+}
+
 // Contains all valid datafusion scalar type except for
 // List
 enum PrimitiveScalarType{
@@ -1132,6 +1145,13 @@ enum PrimitiveScalarType{
     TIME_MICROSECOND = 14;
     TIME_NANOSECOND = 15;
     NULL = 16;
+    DECIMAL128 = 17;
+    DATE64 = 20;
+    TIME_SECOND = 21;
+    TIME_MILLISECOND = 22;
+    INTERVAL_YEARMONTH = 23;
+    INTERVAL_DAYTIME = 24;
+
 }
 
 message ScalarType{

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 406b737d48837..19f895f44d988 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -17,10 +17,10 @@
 
 //! Serde code to convert from protocol buffers to Rust data structures.
 
-use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf;
 use crate::error::BallistaError;
 use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte};
 use crate::{convert_box_required, convert_required};
+use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf;
 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use datafusion::datasource::file_format::avro::AvroFormat;
 use datafusion::datasource::file_format::csv::CsvFormat;
@@ -541,12 +541,46 @@ fn typechecked_scalar_value_conversion(
                         "Untyped scalar null is not a valid scalar value",
                     ))
                 }
+
+                // argo engine add.
+                PrimitiveScalarType::Decimal128 => ScalarValue::Decimal128(None, 0, 0),
+                PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
+                PrimitiveScalarType::TimeSecond => ScalarValue::TimestampSecond(None, None),
+                PrimitiveScalarType::TimeMillisecond => ScalarValue::TimestampMillisecond(None, None),
+                PrimitiveScalarType::IntervalYearmonth => ScalarValue::IntervalYearMonth(None),
+                PrimitiveScalarType::IntervalDaytime => ScalarValue::IntervalDayTime(None),
+                // argo engine add end.
             };
             scalar_value
         } else {
            return Err(proto_error("Could not convert to the proper type"));
        }
    }
+
+        // argo engine add.
+        (Value::Decimal128Value(val), PrimitiveScalarType::Decimal128) => {
+            ScalarValue::Decimal128(
+                Some(val.value.parse::<i128>().unwrap()),
+                val.p as usize,
+                val.s as usize,
+            )
+        }
+        (Value::Date64Value(v), PrimitiveScalarType::Date64) => {
+            ScalarValue::Date64(Some(*v))
+        }
+        (Value::TimeSecondValue(v), PrimitiveScalarType::TimeSecond) => {
+            ScalarValue::TimestampSecond(Some(*v), None)
+        }
+        (Value::TimeMillisecondValue(v), PrimitiveScalarType::TimeMillisecond) => {
+            ScalarValue::TimestampMillisecond(Some(*v), None)
+        }
+        (Value::IntervalYearmonthValue(v), PrimitiveScalarType::IntervalYearmonth) => {
+            ScalarValue::IntervalYearMonth(Some(*v))
+        }
+        (Value::IntervalDaytimeValue(v), PrimitiveScalarType::IntervalDaytime) => {
+            ScalarValue::IntervalDayTime(Some(*v))
+        }
+        // argo engine add end.
         _ => return Err(proto_error("Could not convert to the proper type")),
    })
}
@@ -608,6 +642,30 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value
                    .ok_or_else(|| proto_error("Invalid scalar type"))?
                    .try_into()?
            }
+
+            //argo engine add.
+            protobuf::scalar_value::Value::Decimal128Value(val) => {
+                ScalarValue::Decimal128(
+                    Some(val.value.parse::<i128>().unwrap()),
+                    val.p as usize,
+                    val.s as usize,
+                )
+            }
+            protobuf::scalar_value::Value::Date64Value(v) => {
+                ScalarValue::Date64(Some(*v))
+            }
+            protobuf::scalar_value::Value::TimeSecondValue(v) => {
+                ScalarValue::TimestampSecond(Some(*v), None)
+            }
+            protobuf::scalar_value::Value::TimeMillisecondValue(v) => {
+                ScalarValue::TimestampMillisecond(Some(*v), None)
+            }
+            protobuf::scalar_value::Value::IntervalYearmonthValue(v) => {
+                ScalarValue::IntervalYearMonth(Some(*v))
+            }
+            protobuf::scalar_value::Value::IntervalDaytimeValue(v) => {
+                ScalarValue::IntervalDayTime(Some(*v))
+            } //argo engine add end.
        };
        Ok(scalar)
    }
@@ -764,6 +822,14 @@ impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType
            protobuf::PrimitiveScalarType::TimeNanosecond => {
                ScalarValue::TimestampNanosecond(None, None)
            }
+            // argo engine add.
+            protobuf::PrimitiveScalarType::Decimal128 => ScalarValue::Decimal128(None, 0, 0),
+            protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
+            protobuf::PrimitiveScalarType::TimeSecond => ScalarValue::TimestampSecond(None, None),
+            protobuf::PrimitiveScalarType::TimeMillisecond => ScalarValue::TimestampMillisecond(None, None),
+            protobuf::PrimitiveScalarType::IntervalYearmonth => ScalarValue::IntervalYearMonth(None),
+            protobuf::PrimitiveScalarType::IntervalDaytime => ScalarValue::IntervalDayTime(None),
+            // argo engine add end.
        })
    }
}
@@ -846,6 +912,30 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
                    .ok_or_else(|| proto_error("Protobuf deserialization error found invalid enum variant for DatafusionScalar"))?;
                null_type_enum.try_into()?
            }
+
+            //argo engine add.
+            protobuf::scalar_value::Value::Decimal128Value(val) => {
+                ScalarValue::Decimal128(
+                    Some(val.value.parse::<i128>().unwrap()),
+                    val.p as usize,
+                    val.s as usize,
+                )
+            }
+            protobuf::scalar_value::Value::Date64Value(v) => {
+                ScalarValue::Date64(Some(*v))
+            }
+            protobuf::scalar_value::Value::TimeSecondValue(v) => {
+                ScalarValue::TimestampSecond(Some(*v), None)
+            }
+            protobuf::scalar_value::Value::TimeMillisecondValue(v) => {
+                ScalarValue::TimestampMillisecond(Some(*v), None)
+            }
+            protobuf::scalar_value::Value::IntervalYearmonthValue(v) => {
+                ScalarValue::IntervalYearMonth(Some(*v))
+            }
+            protobuf::scalar_value::Value::IntervalDaytimeValue(v) => {
+                ScalarValue::IntervalDayTime(Some(*v))
+            } //argo engine add end.
        })
    }
}
@@ -969,14 +1059,10 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
             }
             // argo engine add start
             ExprType::AggregateUdfExpr(expr) => {
-                let fun = from_name_to_udaf(expr.fun_name.as_str()).map_err(|e| {
-                    proto_error(format!(
-                        "from_proto error: {}",
-                        e
-                    ))
-                })?;
+                let fun = from_name_to_udaf(expr.fun_name.as_str())
+                    .map_err(|e| proto_error(format!("from_proto error: {}", e)))?;
                 let fun_arc = Arc::new(fun);
-                let fun_args= &expr.args;
+                let fun_args = &expr.args;
                 let args: Vec<Expr> = fun_args
                     .iter()
                     .map(|e| e.try_into())
@@ -1189,14 +1275,14 @@ impl TryInto<Field> for &protobuf::Field {
 }
 
 use crate::serde::protobuf::ColumnStats;
+use datafusion::physical_plan::udaf::AggregateUDF;
 use datafusion::physical_plan::{aggregates, windows};
 use datafusion::prelude::{
     array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256,
     sha384, sha512, trim, upper,
 };
-use std::convert::TryFrom;
 use futures::TryFutureExt;
-use datafusion::physical_plan::udaf::AggregateUDF;
+use std::convert::TryFrom;
 
 impl TryFrom<i32> for protobuf::FileType {
     type Error = BallistaError;

diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 39f9afe737708..c84c60dfcd41e 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -41,6 +41,7 @@ use datafusion::logical_plan::{
 };
 use datafusion::physical_plan::aggregates::AggregateFunction;
 use datafusion::physical_plan::functions::BuiltinScalarFunction;
+use datafusion::physical_plan::udaf::AggregateUDF;
 use datafusion::physical_plan::window_functions::{
     BuiltInWindowFunction, WindowFunction,
 };
@@ -50,12 +51,11 @@ use protobuf::{
     arrow_type, logical_expr_node::ExprType, scalar_type, DateUnit,
     PrimitiveScalarType, ScalarListValue, ScalarType,
 };
+use std::sync::Arc;
 use std::{
     boxed,
     convert::{TryFrom, TryInto},
 };
-use std::sync::Arc;
-use datafusion::physical_plan::udaf::AggregateUDF;
 
 impl protobuf::IntervalUnit {
     pub fn from_arrow_interval_unit(interval_unit: &IntervalUnit) -> Self {
@@ -565,6 +565,52 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
                    Value::TimeNanosecondValue(*s)
                })
            }
+
+            // argo engine add.
+            datafusion::scalar::ScalarValue::Decimal128(val, p, s) => {
+                match *val {
+                    Some(v) => {
+                        protobuf::ScalarValue {
+                            value: Some(Value::Decimal128Value(protobuf::Decimal128 {
+                                value: v.to_string(),
+                                p: *p as i64,
+                                s: *s as i64,
+                            })),
+                        }
+                    }
+                    None => {
+                        protobuf::ScalarValue {
+                            value: Some(protobuf::scalar_value::Value::NullValue(PrimitiveScalarType::Decimal128 as i32))
+                        }
+                    }
+                }
+            }
+            datafusion::scalar::ScalarValue::Date64(val) => {
+                create_proto_scalar(val, PrimitiveScalarType::Date64, |s| {
+                    Value::Date64Value(*s)
+                })
+            }
+            datafusion::scalar::ScalarValue::TimestampSecond(val, _) => {
+                create_proto_scalar(val, PrimitiveScalarType::TimeSecond, |s| {
+                    Value::TimeSecondValue(*s)
+                })
+            }
+            datafusion::scalar::ScalarValue::TimestampMillisecond(val, _) => {
+                create_proto_scalar(val, PrimitiveScalarType::TimeMillisecond, |s| {
+                    Value::TimeMillisecondValue(*s)
+                })
+            }
+            datafusion::scalar::ScalarValue::IntervalYearMonth(val) => {
+                create_proto_scalar(val, PrimitiveScalarType::IntervalYearmonth, |s| {
+                    Value::IntervalYearmonthValue(*s)
+                })
+            }
+            datafusion::scalar::ScalarValue::IntervalDayTime(val) => {
+                create_proto_scalar(val, PrimitiveScalarType::IntervalDaytime, |s| {
+                    Value::IntervalDaytimeValue(*s)
+                })
+            }
+            // argo engine add end.
            _ => {
                return Err(proto_error(format!(
                    "Error converting to Datatype to scalar type, {:?} is invalid as a datafusion scalar.",
@@ -1082,7 +1128,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
             }
             Expr::ScalarUDF { .. } => unimplemented!(),
             // argo engine add start
-            Expr::AggregateUDF { ref fun, ref args } => { 
+            Expr::AggregateUDF { ref fun, ref args } => {
                 let args: Vec<protobuf::LogicalExprNode> = args
                     .iter()
                     .map(|e| e.try_into())

diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 62246a0232df4..7ffbea7c885b6 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -20,7 +20,7 @@
 
 use std::{convert::TryInto, io::Cursor};
 
-use datafusion::arrow::datatypes::UnionMode;
+use datafusion::arrow::datatypes::{IntervalUnit, UnionMode};
 use datafusion::logical_plan::{JoinConstraint, JoinType, Operator};
 use datafusion::physical_plan::aggregates::AggregateFunction;
 use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
@@ -314,6 +314,14 @@ impl Into<DataType> for protobuf::PrimitiveScalarT
                 DataType::Time64(TimeUnit::Nanosecond)
             }
             protobuf::PrimitiveScalarType::Null => DataType::Null,
+            // argo engine add.
+            protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal( 0, 0),
+            protobuf::PrimitiveScalarType::Date64 => DataType::Date64,
+            protobuf::PrimitiveScalarType::TimeSecond => DataType::Timestamp(TimeUnit::Second, None),
+            protobuf::PrimitiveScalarType::TimeMillisecond => DataType::Timestamp(TimeUnit::Millisecond, None),
+            protobuf::PrimitiveScalarType::IntervalYearmonth => DataType::Interval(IntervalUnit::YearMonth),
+            protobuf::PrimitiveScalarType::IntervalDaytime => DataType::Interval(IntervalUnit::DayTime),
+            // argo engine add end.
         }
     }
 }

From 3523138fc15fd8ef6fc8165c8e99f7ce6cd41344 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Sun, 23 Jan 2022 16:45:07 +0800
Subject: [PATCH 07/38] Fix the bug of using the wrong function name when
 serializing UDAFs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/src/serde/physical_plan/to_proto.rs | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 78a50d78bce88..a759c655704fe 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -421,11 +421,12 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
             .map(|e| e.clone().try_into())
             .collect::<Result<Vec<_>, BallistaError>>()?;
         if self.as_any().downcast_ref::<AggregateFunctionExpr>().is_some() {
-            let name = self.name();
+            let name = self.name().to_string();
+            let udaf_fun_name = &name[0..name.find('(').unwrap()];
             Ok(protobuf::PhysicalExprNode {
                 expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateUdfExpr(
                     Box::new(protobuf::PhysicalAggregateUdfExprNode {
-                        fun_name: name.to_string(),
+                        fun_name: udaf_fun_name.to_string(),
                         expr: Some(Box::new(expressions[0].clone())),
                     }),
                 )),

From e6d0acd38fe9471f7a7682209faf77ab052ab033 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Sun, 23 Jan 2022 22:57:29 +0800
Subject: [PATCH 08/38] Fix the bug where UDAF serialization only captured
 the first argument
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/proto/ballista.proto                  | 2 +-
 ballista/rust/core/src/serde/physical_plan/from_proto.rs | 9 ++++++++-
 ballista/rust/core/src/serde/physical_plan/to_proto.rs   | 6 +++---
 3 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index d042805d9aba2..83df3fe55bacd 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -536,7 +536,7 @@ message PhysicalAggregateExprNode {
 // argo engine add.
 message PhysicalAggregateUDFExprNode {
   string fun_name = 1;
-  PhysicalExprNode expr = 2;
+  repeated PhysicalExprNode expr = 2;
 }
 // argo engine add end.

diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index 854701a2cbbba..bc5b44edc7ed1 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -32,6 +32,7 @@ use crate::serde::scheduler::PartitionLocation;
 use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte};
 use crate::{convert_box_required, convert_required, into_required};
 use chrono::{TimeZone, Utc};
+use datafusion::arrow::compute::eq_dyn;
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::catalog::catalog::{
     CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider,
@@ -321,9 +322,15 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                                         e
                                     ))
                                 })?;
+
+                                let args: Vec<Arc<dyn PhysicalExpr>> = agg_node.expr
+                                    .iter()
+                                    .map(|e| e.try_into())
+                                    .collect::<Result<Vec<_>, BallistaError>>()?;
+
                                 Ok(create_aggregate_udf_expr(
                                     &fun,
-                                    &[convert_box_required!(agg_node.expr)?],
+                                    &args,
                                     &physical_schema,
                                     name.to_string(),
                                 )?)
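The hunk above is the substance of patch 08: `PhysicalAggregateUDFExprNode.expr` becomes `repeated`, and deserialization rebuilds the whole argument list instead of wrapping a single `convert_box_required!` value. A minimal sketch with stand-in types (not the prost-generated structs) of the bug this removes: encoding only `expressions[0]` silently drops every further argument of a multi-argument UDAF.

```rust
#[derive(Clone, Debug, PartialEq)]
struct ExprNode(String);

// Stand-in for the generated PhysicalAggregateUdfExprNode after patch 08.
struct UdafNode {
    fun_name: String,
    expr: Vec<ExprNode>, // was a single optional field before the patch
}

fn to_proto(fun_name: &str, args: &[ExprNode]) -> UdafNode {
    UdafNode {
        fun_name: fun_name.to_string(),
        expr: args.to_vec(), // pre-patch behaviour: vec![args[0].clone()]
    }
}

fn main() {
    let args = vec![ExprNode("col_a".into()), ExprNode("col_b".into())];
    let node = to_proto("my_udaf", &args);
    // Fails under the single-expr encoding, which keeps only "col_a".
    assert_eq!(node.expr.len(), 2);
}
```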
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index a759c655704fe..42461f4ecef90 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -425,10 +425,10 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
             let udaf_fun_name = &name[0..name.find('(').unwrap()];
             Ok(protobuf::PhysicalExprNode {
                 expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateUdfExpr(
-                    Box::new(protobuf::PhysicalAggregateUdfExprNode {
+                    protobuf::PhysicalAggregateUdfExprNode {
                         fun_name: udaf_fun_name.to_string(),
-                        expr: Some(Box::new(expressions[0].clone())),
-                    }),
+                        expr: expressions.clone(),
+                    },
                 )),
             })
         } else {

From 9c6695d55cf1190ab329eb102a026d0d919fbb6c Mon Sep 17 00:00:00 2001
From: gaojun
Date: Sun, 23 Jan 2022 23:04:03 +0800
Subject: [PATCH 09/38] Fix the bug where UDAF serialization only captured
 the first argument
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/src/serde/physical_plan/from_proto.rs | 3 ++-
 ballista/rust/core/src/serde/physical_plan/to_proto.rs   | 3 +--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index bc5b44edc7ed1..580dd26b22213 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -316,7 +316,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                             }
                             ExprType::AggregateUdfExpr(agg_node) => {
                                 let name = agg_node.fun_name.as_str();
-                                let fun = from_name_to_udaf(name).map_err(|e| {
+                                let udaf_fun_name = &name[0..name.find('(').unwrap()];
+                                let fun = from_name_to_udaf(udaf_fun_name).map_err(|e| {
                                     proto_error(format!(
                                         "from_proto error: {}",
                                         e

diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 42461f4ecef90..a74b00dc640cd 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -422,11 +422,10 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
             .collect::<Result<Vec<_>, BallistaError>>()?;
         if self.as_any().downcast_ref::<AggregateFunctionExpr>().is_some() {
             let name = self.name().to_string();
-            let udaf_fun_name = &name[0..name.find('(').unwrap()];
             Ok(protobuf::PhysicalExprNode {
                 expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateUdfExpr(
                     protobuf::PhysicalAggregateUdfExprNode {
-                        fun_name: udaf_fun_name.to_string(),
+                        fun_name: name.to_string(),
                         expr: expressions.clone(),

From 7ce997541d4ac5b7f2d2b5c544c80ed22814e135 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Mon, 24 Jan 2022 15:43:46 +0800
Subject: [PATCH 10/38] Serialize the value inside Decimal128 as bytes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/proto/ballista.proto                 | 2 +-
 ballista/rust/core/src/serde/logical_plan/from_proto.rs | 2 +-
 ballista/rust/core/src/serde/logical_plan/to_proto.rs   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 83df3fe55bacd..aeb92961a61e6 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -1119,7 +1119,7 @@ message ScalarValue{
 }
 
 message Decimal128{
-  string value = 1;
+  bytes value = 1;
   int64 p = 2;
   int64 s = 3;
 }

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 19f895f44d988..0cc9389d0f4ef 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -916,7 +916,7 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
             //argo engine add.
             protobuf::scalar_value::Value::Decimal128Value(val) => {
                 ScalarValue::Decimal128(
-                    Some(val.value.parse::<i128>().unwrap()),
+                    Some(i128::from_be_bytes(val.value)),
                     val.p as usize,
                     val.s as usize,
                 )

diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index c84c60dfcd41e..fc299cc718890 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -572,7 +572,7 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
                     Some(v) => {
                         protobuf::ScalarValue {
                             value: Some(Value::Decimal128Value(protobuf::Decimal128 {
-                                value: v.to_string(),
+                                value: v.to_be_bytes(),
                                 p: *p as i64,
                                 s: *s as i64,
                             })),

From 26b6bff02f973cb062463f1ae436580597db9549 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Mon, 24 Jan 2022 16:37:11 +0800
Subject: [PATCH 11/38] Serialize the value inside Decimal128 as bytes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/src/serde/logical_plan/from_proto.rs | 53 ++++++++++++++-----
 .../core/src/serde/logical_plan/to_proto.rs   |  4 +-
 ballista/rust/core/src/serde/mod.rs           | 24 +++++++--
 3 files changed, 61 insertions(+), 20 deletions(-)

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 0cc9389d0f4ef..b9d11c551f16e 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -18,7 +18,9 @@
 //! Serde code to convert from protocol buffers to Rust data structures.
 
 use crate::error::BallistaError;
-use crate::serde::{from_proto_binary_op, proto_error, protobuf, str_to_byte};
+use crate::serde::{
+    from_proto_binary_op, proto_error, protobuf, str_to_byte, vec_to_array,
+};
 use crate::{convert_box_required, convert_required};
 use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf;
 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
@@ -543,12 +545,22 @@ fn typechecked_scalar_value_conversion(
                }
 
                // argo engine add.
-                PrimitiveScalarType::Decimal128 => ScalarValue::Decimal128(None, 0, 0),
+                PrimitiveScalarType::Decimal128 => {
+                    ScalarValue::Decimal128(None, 0, 0)
+                }
                PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
-                PrimitiveScalarType::TimeSecond => ScalarValue::TimestampSecond(None, None),
+                PrimitiveScalarType::TimeSecond => {
+                    ScalarValue::TimestampSecond(None, None)
+                }
-                PrimitiveScalarType::TimeMillisecond => ScalarValue::TimestampMillisecond(None, None),
+                PrimitiveScalarType::TimeMillisecond => {
+                    ScalarValue::TimestampMillisecond(None, None)
+                }
-                PrimitiveScalarType::IntervalYearmonth => ScalarValue::IntervalYearMonth(None),
+                PrimitiveScalarType::IntervalYearmonth => {
+                    ScalarValue::IntervalYearMonth(None)
+                }
-                PrimitiveScalarType::IntervalDaytime => ScalarValue::IntervalDayTime(None),
+                PrimitiveScalarType::IntervalDaytime => {
+                    ScalarValue::IntervalDayTime(None)
+                }
                // argo engine add end.
            };
            scalar_value
@@ -571,8 +583,9 @@ fn typechecked_scalar_value_conversion(
 
        // argo engine add.
        (Value::Decimal128Value(val), PrimitiveScalarType::Decimal128) => {
+            let array = vec_to_array(val.value.clone());
            ScalarValue::Decimal128(
-                Some(val.value.parse::<i128>().unwrap()),
+                Some(i128::from_be_bytes(array)),
                val.p as usize,
                val.s as usize,
            )
@@ -645,8 +658,9 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value
 
        //argo engine add.
        protobuf::scalar_value::Value::Decimal128Value(val) => {
+            let array = vec_to_array(val.value.clone());
            ScalarValue::Decimal128(
-                Some(val.value.parse::<i128>().unwrap()),
+                Some(i128::from_be_bytes(array)),
                val.p as usize,
                val.s as usize,
            )
@@ -823,12 +837,22 @@ impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType
                ScalarValue::TimestampNanosecond(None, None)
            }
            // argo engine add.
-            protobuf::PrimitiveScalarType::Decimal128 => ScalarValue::Decimal128(None, 0, 0),
+            protobuf::PrimitiveScalarType::Decimal128 => {
+                ScalarValue::Decimal128(None, 0, 0)
+            }
            protobuf::PrimitiveScalarType::Date64 => ScalarValue::Date64(None),
-            protobuf::PrimitiveScalarType::TimeSecond => ScalarValue::TimestampSecond(None, None),
+            protobuf::PrimitiveScalarType::TimeSecond => {
+                ScalarValue::TimestampSecond(None, None)
+            }
-            protobuf::PrimitiveScalarType::TimeMillisecond => ScalarValue::TimestampMillisecond(None, None),
+            protobuf::PrimitiveScalarType::TimeMillisecond => {
+                ScalarValue::TimestampMillisecond(None, None)
+            }
-            protobuf::PrimitiveScalarType::IntervalYearmonth => ScalarValue::IntervalYearMonth(None),
+            protobuf::PrimitiveScalarType::IntervalYearmonth => {
+                ScalarValue::IntervalYearMonth(None)
+            }
-            protobuf::PrimitiveScalarType::IntervalDaytime => ScalarValue::IntervalDayTime(None),
+            protobuf::PrimitiveScalarType::IntervalDaytime => {
+                ScalarValue::IntervalDayTime(None)
+            }
            // argo engine add end.
        })
    }
@@ -939,8 +963,9 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
 
            //argo engine add.
            protobuf::scalar_value::Value::Decimal128Value(val) => {
+                let array = vec_to_array(val.value.clone());
                ScalarValue::Decimal128(
-                    Some(i128::from_be_bytes(val.value)),
+                    Some(i128::from_be_bytes(array)),
                    val.p as usize,
                    val.s as usize,
                )

diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index fc299cc718890..910a1800be109 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -570,9 +570,11 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
             datafusion::scalar::ScalarValue::Decimal128(val, p, s) => {
                 match *val {
                     Some(v) => {
+                        let array = v.to_be_bytes();
+                        let vec_val: Vec<u8> = array.iter().cloned().collect();
                         protobuf::ScalarValue {
                             value: Some(Value::Decimal128Value(protobuf::Decimal128 {
-                                value: v.to_be_bytes(),
+                                value: vec_val.clone(),
                                 p: *p as i64,
                                 s: *s as i64,
                             })),

diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 7ffbea7c885b6..0aec9134a3663 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -315,12 +315,20 @@ impl Into<DataType> for protobuf::PrimitiveScalarT
             }
             protobuf::PrimitiveScalarType::Null => DataType::Null,
             // argo engine add.
-            protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal( 0, 0),
+            protobuf::PrimitiveScalarType::Decimal128 => DataType::Decimal(0, 0),
             protobuf::PrimitiveScalarType::Date64 => DataType::Date64,
-            protobuf::PrimitiveScalarType::TimeSecond => DataType::Timestamp(TimeUnit::Second, None),
+            protobuf::PrimitiveScalarType::TimeSecond => {
+                DataType::Timestamp(TimeUnit::Second, None)
+            }
-            protobuf::PrimitiveScalarType::TimeMillisecond => DataType::Timestamp(TimeUnit::Millisecond, None),
+            protobuf::PrimitiveScalarType::TimeMillisecond => {
+                DataType::Timestamp(TimeUnit::Millisecond, None)
+            }
-            protobuf::PrimitiveScalarType::IntervalYearmonth => DataType::Interval(IntervalUnit::YearMonth),
+            protobuf::PrimitiveScalarType::IntervalYearmonth => {
+                DataType::Interval(IntervalUnit::YearMonth)
+            }
-            protobuf::PrimitiveScalarType::IntervalDaytime => DataType::Interval(IntervalUnit::DayTime),
+            protobuf::PrimitiveScalarType::IntervalDaytime => {
+                DataType::Interval(IntervalUnit::DayTime)
+            }
             // argo engine add end.
         }
     }
 }
@@ -383,3 +391,9 @@ fn str_to_byte(s: &str) -> Result {
     }
     Ok(s.as_bytes()[0])
 }
+
+fn vec_to_array<T, const N: usize>(v: Vec<T>) -> [T; N] {
+    v.try_into().unwrap_or_else(|v: Vec<T>| {
+        panic!("Expected a Vec of length {} but it was {}", N, v.len())
+    })
+}

From 0d1db0a04e814debbb141a5a25749b056be7d2b5 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Mon, 24 Jan 2022 17:04:49 +0800
Subject: [PATCH 12/38] fix code style

---
 .../core/src/serde/logical_plan/from_proto.rs | 12 ++++-------
 .../core/src/serde/logical_plan/to_proto.rs   |  4 ++--
 ballista/rust/core/src/serde/mod.rs           |  3 +--
 .../core/src/serde/physical_plan/to_proto.rs  | 20 ++++++++++++-------
 4 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index b9d11c551f16e..866323e56c166 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -560,8 +560,7 @@ fn typechecked_scalar_value_conversion(
                 }
                 PrimitiveScalarType::IntervalDaytime => {
                     ScalarValue::IntervalDayTime(None)
-                }
-                // argo engine add end.
+                } // argo engine add end.
             };
             scalar_value
         } else {
@@ -592,8 +591,7 @@ fn typechecked_scalar_value_conversion(
         }
         (Value::IntervalDaytimeValue(v), PrimitiveScalarType::IntervalDaytime) => {
             ScalarValue::IntervalDayTime(Some(*v))
-        }
-        // argo engine add end.
+        } // argo engine add end.
         _ => return Err(proto_error("Could not convert to the proper type")),
     })
 }
@@ -852,8 +850,7 @@ impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType
             }
             protobuf::PrimitiveScalarType::IntervalDaytime => {
                 ScalarValue::IntervalDayTime(None)
-            }
-            // argo engine add end.
+            } // argo engine add end.
         })
     }
 }
@@ -1096,8 +1093,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                     fun: fun_arc,
                     args: args.try_into().unwrap(),
                 })
-            }
-            // argo engine add end
+            } // argo engine add end
             ExprType::Alias(alias) => Ok(Expr::Alias(
                 Box::new(parse_required_expr(&alias.expr)?),
                 alias.alias.clone(),

diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 910a1800be109..64e021371ccc9 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -571,10 +571,10 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
                 match *val {
                     Some(v) => {
                         let array = v.to_be_bytes();
-                        let vec_val: Vec<u8> = array.iter().cloned().collect();
+                        let vec_val: Vec<u8> = array.to_vec();
                         protobuf::ScalarValue {
                             value: Some(Value::Decimal128Value(protobuf::Decimal128 {
-                                value: vec_val.clone(),
+                                value: vec_val,
                                 p: *p as i64,
                                 s: *s as i64,
                             })),

diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index 0aec9134a3663..a8a0deb25054b 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -328,8 +328,7 @@ impl Into<DataType> for protobuf::PrimitiveScalarT
             }
             protobuf::PrimitiveScalarType::IntervalDaytime => {
                 DataType::Interval(IntervalUnit::DayTime)
-            }
-            // argo engine add end.
+            } // argo engine add end.
         }
     }
 }

diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index a74b00dc640cd..908091f611cbd 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -420,15 +420,21 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
             .iter()
             .map(|e| e.clone().try_into())
             .collect::<Result<Vec<_>, BallistaError>>()?;
-        if self.as_any().downcast_ref::<AggregateFunctionExpr>().is_some() {
+        if self
+            .as_any()
+            .downcast_ref::<AggregateFunctionExpr>()
+            .is_some()
+        {
             let name = self.name().to_string();
             Ok(protobuf::PhysicalExprNode {
-                expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateUdfExpr(
-                    protobuf::PhysicalAggregateUdfExprNode {
-                        fun_name: name.to_string(),
-                        expr: expressions.clone(),
-                    },
-                )),
+                expr_type: Some(
+                    protobuf::physical_expr_node::ExprType::AggregateUdfExpr(
+                        protobuf::PhysicalAggregateUdfExprNode {
+                            fun_name: name.to_string(),
+                            expr: expressions.clone(),
+                        },
+                    ),
+                ),
             })
         } else {
             let aggr_function = if self.as_any().downcast_ref::<Avg>().is_some() {

From ddcab1677f7a52b72a537563a5dcf4b7bc5c1d4f Mon Sep 17 00:00:00 2001
From: gaojun
Date: Mon, 7 Feb 2022 18:07:49 +0800
Subject: [PATCH 13/38] Modify udf.rs in datafusion to add the UDF's own
 physical execution plan expr
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ballista/rust/core/proto/ballista.proto       | 13 ++++
 .../core/src/serde/logical_plan/from_proto.rs | 17 ++++-
 .../core/src/serde/logical_plan/to_proto.rs   | 17 ++++-
 .../src/serde/physical_plan/from_proto.rs     | 76 ++++++++++++++++++-
 .../core/src/serde/physical_plan/to_proto.rs  |  1 +
 datafusion/src/physical_plan/udf.rs           | 73 ++++++++++++++++--
 6 files changed, 186 insertions(+), 11 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index aeb92961a61e6..4027f49101001 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -84,6 +84,7 @@ message LogicalExprNode {
 
     //ArgoEngineAggregateUDF expressions
     AggregateUDFExprNode aggregate_udf_expr = 19;
+//    ScalarUDFExprNode scalar_udf_expr = 20;
   }
 }
 
@@ -524,6 +525,8 @@ message PhysicalExprNode {
 
     // argo engine add.
     PhysicalAggregateUDFExprNode aggregate_udf_expr = 16;
+
+//    PhysicalScalarUDFExprNode scalar_udf_expr = 17;
     // argo engine add end.
   }
 }
 
@@ -538,6 +541,11 @@ message PhysicalAggregateUDFExprNode {
   string fun_name = 1;
   repeated PhysicalExprNode expr = 2;
 }
+//
+//message PhysicalScalarUDFExprNode {
+//  string fun_name = 1;
+//  repeated PhysicalExprNode expr = 2;
+//}
 // argo engine add end.
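Before the proto changes continue below, it is worth pinning down the Decimal128 encoding that patches 10-12 converge on: the `i128` travels as its 16 big-endian bytes and is rebuilt with `vec_to_array` plus `i128::from_be_bytes`, replacing the string `parse::<i128>()` from patch 06. A self-contained round trip, using the same `vec_to_array` helper the series adds to `serde/mod.rs`:

```rust
// Identical to the helper added in ballista/rust/core/src/serde/mod.rs.
fn vec_to_array<T, const N: usize>(v: Vec<T>) -> [T; N] {
    v.try_into().unwrap_or_else(|v: Vec<T>| {
        panic!("Expected a Vec of length {} but it was {}", N, v.len())
    })
}

fn main() {
    let value: i128 = -123_456_789_012_345_678_901_234_567;
    // What to_proto.rs stores in Decimal128.value (the bytes on the wire).
    let wire: Vec<u8> = value.to_be_bytes().to_vec();
    // What from_proto.rs reconstructs; N = 16 is inferred from i128.
    let back = i128::from_be_bytes(vec_to_array(wire));
    assert_eq!(value, back);
}
```

Big-endian byte order is an arbitrary but fixed choice; it only has to match on both sides of the wire. Note that precision and scale still travel separately as `p` and `s`.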
message PhysicalWindowExprNode { @@ -998,6 +1006,11 @@ message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; } +// +//message ScalarUDFExprNode { +// string fun_name = 1; +// repeated LogicalExprNode args = 2; +//} /////////////////////////////////////////////////////////////////////////////////////////////////// // Arrow Data Types diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 866323e56c166..3647c14d4ccc7 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1093,7 +1093,21 @@ impl TryInto for &protobuf::LogicalExprNode { fun: fun_arc, args: args.try_into().unwrap(), }) - } // argo engine add end + } + // ExprType::ScalarUDFExpr(expr) => { + // let fun = from_name_to_udf(expr.fun_name.as_str()) + // .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; + // let fun_arc = Arc::new(fun); + // let fun_args = &expr.args; + // let args: Vec = fun_args + // .iter() + // .map(|e| e.try_into()) + // .collect::, BallistaError>>()?; + // Ok(Expr::ScalarUDF { + // fun: fun_arc, + // args: args.try_into().unwrap(), + // }) + // } // argo engine add end ExprType::Alias(alias) => Ok(Expr::Alias( Box::new(parse_required_expr(&alias.expr)?), alias.alias.clone(), @@ -1304,6 +1318,7 @@ use datafusion::prelude::{ }; use futures::TryFutureExt; use std::convert::TryFrom; +use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; impl TryFrom for protobuf::FileType { type Error = BallistaError; diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 64e021371ccc9..fc85089bcd4fb 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1128,8 +1128,23 @@ impl TryInto for &Expr { ), }) } - Expr::ScalarUDF { .. } => unimplemented!(), // argo engine add start + // Expr::ScalarUDF { ref fun, ref args } => { + // let args: Vec = args + // .iter() + // .map(|e| e.try_into()) + // .collect::, BallistaError>>()?; + // Ok(protobuf::LogicalExprNode { + // expr_type: Some( + // protobuf::logical_expr_node::ExprType::ScalarUDFExpr( + // protobuf::ScalarUDFExprNode { + // fun_name: fun.name.clone(), + // args, + // }, + // ), + // ), + // }) + // } Expr::AggregateUDF { ref fun, ref args } => { let args: Vec = args .iter() diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 580dd26b22213..8928eaeac9f1c 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -21,6 +21,7 @@ use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; +use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; use crate::error::BallistaError; use crate::execution_plans::{ @@ -82,6 +83,7 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::physical_plan::udf::create_physical_expr; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::physical_expr_node::ExprType; @@ -314,6 +316,28 @@ impl TryInto> for &protobuf::PhysicalPlanNode { name.to_string(), )?) } + // argo engine add. 
+ // ExprType::ScalarUDFExpr(udf_node) => { + // let name = udf_node.fun_name.as_str(); + // let udf_fun_name = &name[0..name.find('(').unwrap()]; + // let fun = from_name_to_udf(udf_fun_name).map_err(|e| { + // proto_error(format!( + // "from_proto error: {}", + // e + // )) + // })?; + // + // let args: Vec> = udf_node.expr + // .iter() + // .map(|e| e.try_into()) + // .collect::, BallistaError>>()?; + // + // Ok(create_physical_expr( + // &fun, + // &args, + // &physical_schema, + // )?) + // } ExprType::AggregateUdfExpr(agg_node) => { let name = agg_node.fun_name.as_str(); let udaf_fun_name = &name[0..name.find('(').unwrap()]; @@ -335,7 +359,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { &physical_schema, name.to_string(), )?) - } + } // argo engine add end. _ => Err(BallistaError::General( "Invalid aggregate expression for HashAggregateExec" .to_string(), @@ -570,12 +594,60 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { .to_owned(), )); } + // argo engine add. + // ExprType::ScalarUDFExpr(e) => { + // let fun = from_name_to_udf(e.fun_name).map_err(|e| { + // proto_error(format!( + // "from_proto error: {}", + // e + // )) + // })?; + // + // let args = e + // .args + // .iter() + // .map(|x| x.try_into()) + // .collect::, _>>()?; + // + // let catalog_list = + // Arc::new(MemoryCatalogList::new()) as Arc; + // + // let ctx_state = ExecutionContextState { + // catalog_list, + // scalar_functions: Default::default(), + // var_provider: Default::default(), + // aggregate_functions: Default::default(), + // config: ExecutionConfig::new(), + // execution_props: ExecutionProps::new(), + // object_store_registry: Arc::new(ObjectStoreRegistry::new()), + // runtime_env: Arc::new(RuntimeEnv::default()), + // }; + // + // let fun_expr = fun.fun; + // + // Arc::new(ScalarFunctionExpr::new( + // &e.name, + // fun_expr, + // args, + // &convert_required!(e.return_type)?, + // )) + // + // let name = udf_node.fun_name.as_str(); + // let udf_fun_name = &name[0..name.find('(').unwrap()]; + // let fun = from_name_to_udf(udf_fun_name).map_err(|e| { + // proto_error(format!( + // "from_proto error: {}", + // e + // )) + // })?; + // Arc::new(fun) + // } ExprType::AggregateUdfExpr(_) => { return Err(BallistaError::General( "Cannot convert aggregate udf expr node to physical expression" .to_owned(), )); - } + } // argo engine add end. ExprType::WindowExpr(_) => { return Err(BallistaError::General( "Cannot convert window expr node to physical expression".to_owned(), diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 908091f611cbd..2b04addce318b 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -75,6 +75,7 @@ use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::functions::{BuiltinScalarFunction, ScalarFunctionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::udaf::AggregateFunctionExpr; +use datafusion::physical_plan::udf::ScalarUDF; impl TryInto for Arc { type Error = BallistaError; diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index af0765877c1b6..6e6dcb6426d0e 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -18,9 +18,11 @@ //! 
UDF support use fmt::{Debug, Formatter}; +use std::any::Any; use std::fmt; +use std::fmt::Display; -use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Schema}; use crate::error::Result; use crate::{logical_plan::Expr, physical_plan::PhysicalExpr}; @@ -31,6 +33,8 @@ use super::{ }, type_coercion::coerce, }; +use crate::physical_plan::ColumnarValue; +use arrow::record_batch::RecordBatch; use std::sync::Arc; /// Logical representation of a UDF. @@ -121,10 +125,65 @@ pub fn create_physical_expr( .map(|e| e.data_type(input_schema)) .collect::>>()?; - Ok(Arc::new(ScalarFunctionExpr::new( - &fun.name, - fun.fun.clone(), - coerced_phy_exprs, - (fun.return_type)(&coerced_exprs_types)?.as_ref(), - ))) + Ok(Arc::new(ScalarUDFExpr { + fun: fun.clone(), + name: fun.name.clone(), + args: coerced_phy_exprs.clone(), + return_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(), + })) } + +/// Physical expression of a UDF. +/// argo engine add +#[derive(Debug)] +pub struct ScalarUDFExpr { + fun: ScalarUDF, + name: String, + args: Vec>, + return_type: DataType, } + +impl fmt::Display for ScalarUDFExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}({})", + self.name, + self.args + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } +} + +impl PhysicalExpr for ScalarUDFExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + // evaluate each argument expression against the batch + // TODO: functions with zero input arguments are not supported yet + let inputs = self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + // evaluate the function + let fun = self.fun.fun.as_ref(); + (fun)(&inputs) + } +} // argo engine add end. From 3d5f74d8891ff3f9b2442a8b4dad06515e8ec7fc Mon Sep 17 00:00:00 2001 From: gaojun Date: Mon, 7 Feb 2022 22:59:32 +0800 Subject: [PATCH 14/38] Add ScalarUDFExpr implementing PhysicalExpr MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ballista/rust/core/proto/ballista.proto | 25 ++--- .../core/src/serde/logical_plan/from_proto.rs | 28 +++--- .../core/src/serde/logical_plan/to_proto.rs | 32 +++---- .../src/serde/physical_plan/from_proto.rs | 92 +++++-------------- .../core/src/serde/physical_plan/to_proto.rs | 20 +++- 5 files changed, 84 insertions(+), 113 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 4027f49101001..699ebbf432620 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -84,7 +84,7 @@ message LogicalExprNode { //ArgoEngineAggregateUDF expressions AggregateUDFExprNode aggregate_udf_expr = 19; -// ScalarUDFExprNode scalar_udf_expr = 20; + ScalarUDFProtoExprNode scalar_udf_proto_expr = 20; } } @@ -526,7 +526,7 @@ message PhysicalExprNode { // argo engine add. PhysicalAggregateUDFExprNode aggregate_udf_expr = 16; -// PhysicalScalarUDFExprNode scalar_udf_expr = 17; + PhysicalScalarUDFProtoExprNode scalar_udf_proto_expr = 17; // argo engine add end.
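ScalarUDFExpr::evaluate above has a two-step shape: evaluate every argument expression against the incoming batch, then apply the stored function pointer to the evaluated columns. A stripped-down model of that flow, with plain vectors standing in for arrow's RecordBatch and ColumnarValue (types simplified for illustration):

    type Column = Vec<i64>;
    type ArgExpr = Box<dyn Fn(&[Column]) -> Result<Column, String>>;

    struct MiniScalarUdfExpr {
        fun: fn(&[Column]) -> Result<Column, String>,
        args: Vec<ArgExpr>,
    }

    impl MiniScalarUdfExpr {
        fn evaluate(&self, batch: &[Column]) -> Result<Column, String> {
            // 1. evaluate each argument sub-expression against the batch
            let inputs = self
                .args
                .iter()
                .map(|e| e(batch))
                .collect::<Result<Vec<_>, _>>()?;
            // 2. hand the evaluated inputs to the UDF implementation
            (self.fun)(&inputs)
        }
    }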
} } @@ -541,11 +541,12 @@ message PhysicalAggregateUDFExprNode { string fun_name = 1; repeated PhysicalExprNode expr = 2; } -// -//message PhysicalScalarUDFExprNode { -// string fun_name = 1; -// repeated PhysicalExprNode expr = 2; -//} + +message PhysicalScalarUDFProtoExprNode { + string fun_name = 1; + repeated PhysicalExprNode expr = 2; + ArrowType return_type = 3; +} // argo engine add end. message PhysicalWindowExprNode { @@ -1006,11 +1007,11 @@ message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; } -// -//message ScalarUDFExprNode { -// string fun_name = 1; -// repeated LogicalExprNode args = 2; -//} + +message ScalarUDFProtoExprNode { + string fun_name = 1; + repeated LogicalExprNode args = 2; +} /////////////////////////////////////////////////////////////////////////////////////////////////// // Arrow Data Types diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 3647c14d4ccc7..43a38e6024664 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1094,20 +1094,20 @@ impl TryInto for &protobuf::LogicalExprNode { args: args.try_into().unwrap(), }) } - // ExprType::ScalarUDFExpr(expr) => { - // let fun = from_name_to_udf(expr.fun_name.as_str()) - // .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; - // let fun_arc = Arc::new(fun); - // let fun_args = &expr.args; - // let args: Vec = fun_args - // .iter() - // .map(|e| e.try_into()) - // .collect::, BallistaError>>()?; - // Ok(Expr::ScalarUDF { - // fun: fun_arc, - // args: args.try_into().unwrap(), - // }) - // } // argo engine add end + ExprType::ScalarUdfProtoExpr(expr) => { + let fun = from_name_to_udf(expr.fun_name.as_str()) + .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; + let fun_arc = Arc::new(fun); + let fun_args = &expr.args; + let args: Vec = fun_args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(Expr::ScalarUDF { + fun: fun_arc, + args: args.try_into().unwrap(), + }) + } // argo engine add end ExprType::Alias(alias) => Ok(Expr::Alias( Box::new(parse_required_expr(&alias.expr)?), alias.alias.clone(), diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index fc85089bcd4fb..dcaf9b7f5eb8e 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1129,22 +1129,22 @@ impl TryInto for &Expr { }) } // argo engine add start - // Expr::ScalarUDF { ref fun, ref args } => { - // let args: Vec = args - // .iter() - // .map(|e| e.try_into()) - // .collect::, BallistaError>>()?; - // Ok(protobuf::LogicalExprNode { - // expr_type: Some( - // protobuf::logical_expr_node::ExprType::ScalarUDFExpr( - // protobuf::ScalarUDFExprNode { - // fun_name: fun.name.clone(), - // args, - // }, - // ), - // ), - // }) - // } + Expr::ScalarUDF { ref fun, ref args } => { + let args: Vec = args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(protobuf::LogicalExprNode { + expr_type: Some( + protobuf::logical_expr_node::ExprType::ScalarUdfProtoExpr( + protobuf::ScalarUdfProtoExprNode { + fun_name: fun.name.clone(), + args, + }, + ), + ), + }) + } Expr::AggregateUDF { ref fun, ref args } => { let args: Vec = args .iter() diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs 
b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 8928eaeac9f1c..c1c19b724717c 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -83,7 +83,7 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; -use datafusion::physical_plan::udf::create_physical_expr; +use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr}; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::physical_expr_node::ExprType; @@ -316,28 +316,6 @@ impl TryInto> for &protobuf::PhysicalPlanNode { name.to_string(), )?) } - // argo engine add. - // ExprType::ScalarUDFExpr(udf_node) => { - // let name = udf_node.fun_name.as_str(); - // let udf_fun_name = &name[0..name.find('(').unwrap()]; - // let fun = from_name_to_udf(udf_fun_name).map_err(|e| { - // proto_error(format!( - // "from_proto error: {}", - // e - // )) - // })?; - // - // let args: Vec> = udf_node.expr - // .iter() - // .map(|e| e.try_into()) - // .collect::, BallistaError>>()?; - // - // Ok(create_physical_expr( - // &fun, - // &args, - // &physical_schema, - // )?) - // } ExprType::AggregateUdfExpr(agg_node) => { let name = agg_node.fun_name.as_str(); let udaf_fun_name = &name[0..name.find('(').unwrap()]; @@ -595,53 +573,27 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { )); } // argo engine add. - // ExprType::ScalarUDFExpr(e) => { - // let fun = from_name_to_udf(e.fun_name).map_err(|e| { - // proto_error(format!( - // "from_proto error: {}", - // e - // )) - // })?; - // - // let args = e - // .args - // .iter() - // .map(|x| x.try_into()) - // .collect::, _>>()?; - // - // let catalog_list = - // Arc::new(MemoryCatalogList::new()) as Arc; - // - // let ctx_state = ExecutionContextState { - // catalog_list, - // scalar_functions: Default::default(), - // var_provider: Default::default(), - // aggregate_functions: Default::default(), - // config: ExecutionConfig::new(), - // execution_props: ExecutionProps::new(), - // object_store_registry: Arc::new(ObjectStoreRegistry::new()), - // runtime_env: Arc::new(RuntimeEnv::default()), - // }; - // - // let fun_expr = fun.fun; - // - // Arc::new(ScalarFunctionExpr::new( - // &e.name, - // fun_expr, - // args, - // &convert_required!(e.return_type)?, - // )) - // - // let name = udf_node.fun_name.as_str(); - // let udf_fun_name = &name[0..name.find('(').unwrap()]; - // let fun = from_name_to_udf(udf_fun_name).map_err(|e| { - // proto_error(format!( - // "from_proto error: {}", - // e - // )) - // })?; - // Arc::new(fun) - // } + ExprType::ScalarUDFProtoExpr(e) => { + let fun = from_name_to_udf(e.fun_name).map_err(|e| { + proto_error(format!( + "from_proto error: {}", + e + )) + })?; + + let args = e + .args + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?; + + Arc::new(ScalarUDFExpr{ + fun: fun.clone(), + name: e.fun_name, + args, + return_type: convert_required!(e.return_type)?, + }) + } ExprType::AggregateUdfExpr(_) => { return Err(BallistaError::General( "Cannot convert aggregate udf expr node to physical expression" diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 2b04addce318b..b7220798ab09f 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -75,7 +75,7 @@ use 
datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::functions::{BuiltinScalarFunction, ScalarFunctionExpr}; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::udaf::AggregateFunctionExpr; -use datafusion::physical_plan::udf::ScalarUDF; +use datafusion::physical_plan::udf::{ScalarUDF, ScalarUDFExpr}; impl TryInto for Arc { type Error = BallistaError; @@ -619,6 +619,24 @@ impl TryFrom> for protobuf::PhysicalExprNode { }, )), }) + } else if let Some(expr) = expr.downcast_ref::() { + let fun: ScalarUDF = expr.fun.clone(); + let args: Vec = expr + .args() + .iter() + .map(|e| e.to_owned().try_into()) + .collect::, _>>()?; + let data_type = expr.return_type.clone(); + let return_type = (&data_type).into(); + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdfProtoExpr( + protobuf::PhysicalScalarUdfProtoExprNode { + fun_name: expr.name.to_string(), + expr: args, + return_type: Some(return_type), + }, + )), + }) } else { Err(BallistaError::General(format!( "physical_plan::to_proto() unsupported expression {:?}", From e2db015e17e2de8f820b2819531cbc5666095f05 Mon Sep 17 00:00:00 2001 From: gaojun Date: Mon, 7 Feb 2022 23:07:10 +0800 Subject: [PATCH 15/38] Add ScalarUDFExpr implementing PhysicalExpr MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/src/serde/physical_plan/to_proto.rs | 2 +- datafusion/src/physical_plan/udf.rs | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index b7220798ab09f..fce7a7b9ea2b0 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -620,7 +620,7 @@ impl TryFrom> for protobuf::PhysicalExprNode { )), }) } else if let Some(expr) = expr.downcast_ref::() { - let fun: ScalarUDF = expr.fun.clone(); + let fun: ScalarUDF = expr.fun().clone(); diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 6e6dcb6426d0e..fa4575d0fb864 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -143,6 +143,39 @@ pub struct ScalarUDFExpr { return_type: DataType, } +impl ScalarUDFExpr { + pub fn new( + name: &str, + fun: ScalarUDF, + args: Vec>, + return_type: &DataType, + ) -> Self { + Self { + fun, + name: name.to_string(), + args, + return_type: return_type.clone(), + } + } + + pub fn fun(&self) -> &ScalarUDF { + &self.fun + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn args(&self) -> &[Arc] { + &self.args + } + + /// Data type produced by this expression + pub fn return_type(&self) -> &DataType { + &self.return_type + } +} + impl fmt::Display for ScalarUDFExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( From 4fd24b6c70ba4a56a775d367d8bf1182c9179518 Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 8 Feb 2022 10:11:07 +0800 Subject: [PATCH 16/38] Add UDF support to ballista MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/src/serde/logical_plan/from_proto.rs | 12 +++------ .../src/serde/physical_plan/from_proto.rs | 26 ++++++++----------- .../core/src/serde/physical_plan/to_proto.rs | 24
++++++++--------- datafusion/src/physical_plan/udf.rs | 5 +++- 4 files changed, 30 insertions(+), 37 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 43a38e6024664..dc621a9b0fe67 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1089,10 +1089,7 @@ impl TryInto for &protobuf::LogicalExprNode { .iter() .map(|e| e.try_into()) .collect::, BallistaError>>()?; - Ok(Expr::AggregateUDF { - fun: fun_arc, - args: args.try_into().unwrap(), - }) + Ok(Expr::AggregateUDF { fun: fun_arc, args }) } ExprType::ScalarUdfProtoExpr(expr) => { let fun = from_name_to_udf(expr.fun_name.as_str()) @@ -1103,10 +1100,7 @@ impl TryInto for &protobuf::LogicalExprNode { .iter() .map(|e| e.try_into()) .collect::, BallistaError>>()?; - Ok(Expr::ScalarUDF { - fun: fun_arc, - args: args.try_into().unwrap(), - }) + Ok(Expr::ScalarUDF { fun: fun_arc, args }) } // argo engine add end ExprType::Alias(alias) => Ok(Expr::Alias( Box::new(parse_required_expr(&alias.expr)?), @@ -1310,6 +1304,7 @@ impl TryInto for &protobuf::Field { } use crate::serde::protobuf::ColumnStats; +use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; use datafusion::physical_plan::udaf::AggregateUDF; use datafusion::physical_plan::{aggregates, windows}; use datafusion::prelude::{ @@ -1318,7 +1313,6 @@ use datafusion::prelude::{ }; use futures::TryFutureExt; use std::convert::TryFrom; -use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; impl TryFrom for protobuf::FileType { type Error = BallistaError; diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index c1c19b724717c..37249b9abe33f 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -18,10 +18,10 @@ //! Serde code to convert from protocol buffers to Rust data structures. use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf; +use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; -use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; use crate::error::BallistaError; use crate::execution_plans::{ @@ -59,6 +59,7 @@ use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions}; use datafusion::physical_plan::udaf::create_aggregate_expr as create_aggregate_udf_expr; +use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -83,7 +84,6 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; -use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr}; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::physical_expr_node::ExprType; @@ -573,26 +573,22 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { )); } // argo engine add. 
- ExprType::ScalarUDFProtoExpr(e) => { - let fun = from_name_to_udf(e.fun_name).map_err(|e| { - proto_error(format!( - "from_proto error: {}", - e - )) - })?; + ExprType::ScalarUdfProtoExpr(e) => { + let fun = from_name_to_udf(&e.fun_name) + .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; let args = e - .args + .expr .iter() .map(|x| x.try_into()) .collect::, _>>()?; - Arc::new(ScalarUDFExpr{ - fun: fun.clone(), - name: e.fun_name, + Arc::new(ScalarUDFExpr::new( + e.fun_name.as_str(), + fun, args, - return_type: convert_required!(e.return_type)?, - }) + &convert_required!(e.return_type)?, + )) } ExprType::AggregateUdfExpr(_) => { return Err(BallistaError::General( diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index fce7a7b9ea2b0..357fb9f1910cc 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -426,13 +426,12 @@ impl TryInto for Arc { .downcast_ref::() .is_some() { - let name = self.name().to_string(); Ok(protobuf::PhysicalExprNode { expr_type: Some( protobuf::physical_expr_node::ExprType::AggregateUdfExpr( protobuf::PhysicalAggregateUdfExprNode { - fun_name: name.to_string(), - expr: expressions.clone(), + fun_name: self.name().to_string(), + expr: expressions, }, ), ), @@ -620,22 +619,23 @@ impl TryFrom> for protobuf::PhysicalExprNode { )), }) } else if let Some(expr) = expr.downcast_ref::() { - let fun: ScalarUDF = expr.fun().clone(); let args: Vec = expr .args() .iter() .map(|e| e.to_owned().try_into()) .collect::, _>>()?; - let data_type = expr.return_type.clone(); + let data_type = expr.return_type().clone(); let return_type = (&data_type).into(); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdfProtoExpr( - protobuf::PhysicalScalarUdfProtoExprNode { - fun_name: expr.name.to_string(), - expr: args, - return_type: Some(return_type), - }, - )), + expr_type: Some( + protobuf::physical_expr_node::ExprType::ScalarUdfProtoExpr( + protobuf::PhysicalScalarUdfProtoExprNode { + fun_name: expr.name().to_string(), + expr: args, + return_type: Some(return_type), + }, + ), + ), }) } else { Err(BallistaError::General(format!( "physical_plan::to_proto() unsupported expression {:?}", diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index fa4575d0fb864..4d8fdd8b57b8f 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -20,7 +20,6 @@ use fmt::{Debug, Formatter}; use std::any::Any; use std::fmt; -use std::fmt::Display; use arrow::datatypes::{DataType, Schema}; @@ -144,6 +143,7 @@ pub struct ScalarUDFExpr { } impl ScalarUDFExpr { + /// Create a ScalarUDFExpr pub fn new( name: &str, fun: ScalarUDF, args: Vec>, return_type: &DataType, @@ -158,14 +158,17 @@ impl ScalarUDFExpr { } } + /// Return the wrapped ScalarUDF pub fn fun(&self) -> &ScalarUDF { &self.fun } + /// Return the function name pub fn name(&self) -> &str { &self.name } + /// Return the argument expressions pub fn args(&self) -> &[Arc] { &self.args } From 9a26e5cc4531cfde804c96f01be7969ffda0b88e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 06:35:48 +0000 Subject: [PATCH 17/38] Update .gitmodules --- .gitmodules | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitmodules b/.gitmodules index ad5ffe40c05f1..3670814a9bd9c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "parquet-testing"] path = parquet-testing - url = http://git.analysysdata.com/noah/parquet-testing.git + url =
http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/parquet-testing.git [submodule "testing"] path = testing - url = http://git.analysysdata.com/noah/arrow-testing.git + url = http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-testing.git From 4fb1487d500a5e6fef988a35713a65a0ffad4eea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 06:53:22 +0000 Subject: [PATCH 18/38] Update Cargo.toml --- ballista/rust/scheduler/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 3cae3f7fcad0a..7f412ab9e906c 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -35,7 +35,7 @@ anyhow = "1" ballista-core = { path = "../core"} clap = "2" configure_me = "0.4.0" -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } env_logger = "0.9" etcd-client = { version = "0.7", optional = true } futures = "0.3" From d28ca75e7930bcc0b7d5a8f7d62f0874082dedd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 06:53:39 +0000 Subject: [PATCH 19/38] Update Cargo.toml --- ballista/rust/executor/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index d8c2aafd2b518..1813a025f5746 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -35,7 +35,7 @@ anyhow = "1" async-trait = "0.1.36" ballista-core = { path = "../core"} configure_me = "0.4.0" -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } env_logger = "0.9" futures = "0.3" log = "0.4" From 52134acd271df1ff1dfedd92ab411f269fb52f5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 06:53:59 +0000 Subject: [PATCH 20/38] Update Cargo.toml --- ballista/rust/core/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index fc1a550ed902c..8a3991acf7a93 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -45,8 +45,8 @@ chrono = { version = "0.4", default-features = false } arrow-flight = { version = "7.0.0" } -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } -argo-engine-common = { git = "http://git.analysysdata.com/noah/argo_engine.git", branch="master", package = "common" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +argo-engine-common = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/argo_engine.git", branch="master", package = "common" } [dev-dependencies] tempfile = "3" From 104e0e0fdfaa9258c0735fc84698db645c14fdde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 06:54:15 +0000 Subject: [PATCH 21/38] Update Cargo.toml --- ballista/rust/client/Cargo.toml | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index a4dabc15f5d01..99c07ba2186b3 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -34,7 +34,7 @@ futures = "0.3" log = "0.4" tokio = "1.0" -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } [features] default = [] From b2f9e41e360f0c6e3c74fd16488cae8650f7ac87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 06:59:14 +0000 Subject: [PATCH 22/38] Update Cargo.toml --- ballista-examples/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index 0461ccaae8f6f..7a19f6346b614 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -29,7 +29,7 @@ publish = false rust-version = "1.57" [dependencies] -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } ballista = { path = "../ballista/rust/client", version = "0.6.0"} prost = "0.9" tonic = "0.6" From e8b14864df8415bcc72297bb4999e7b3bf635a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 07:00:52 +0000 Subject: [PATCH 23/38] Update Cargo.toml --- benchmarks/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 4b68412d17a13..d7fa1c1ae942f 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -32,7 +32,7 @@ simd = ["datafusion/simd"] snmalloc = ["snmalloc-rs"] [dependencies] -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } ballista = { path = "../ballista/rust/client" } structopt = { version = "0.3", default-features = false } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread"] } From b78a1eba11149df5050f6c59a89d637537c2c86d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E6=8C=AF=E5=85=B4?= Date: Tue, 8 Feb 2022 07:01:10 +0000 Subject: [PATCH 24/38] Update Cargo.toml --- datafusion-cli/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 93fa5838c0e9d..1e45bf9bceff8 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,6 +30,6 @@ rust-version = "1.57" clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } -datafusion = { git = "http://git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } arrow = { version = "7.0.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } From b2a0987582b9394e59fe100904bf430b51d38f02 Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 15 Feb 2022 11:26:48 +0800 
Subject: [PATCH 25/38] Load UDFs and UDAFs as plugins MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ballista-examples/Cargo.toml | 2 +- ballista/rust/client/Cargo.toml | 2 +- ballista/rust/core/Cargo.toml | 4 +- .../core/src/serde/logical_plan/from_proto.rs | 34 +++- .../src/serde/physical_plan/from_proto.rs | 35 +++- ballista/rust/executor/Cargo.toml | 2 +- ballista/rust/scheduler/Cargo.toml | 2 +- benchmarks/Cargo.toml | 2 +- datafusion-cli/Cargo.toml | 2 +- datafusion/Cargo.toml | 2 + datafusion/src/execution/context.rs | 44 ++++- datafusion/src/execution/mod.rs | 39 ++++ .../src/execution/udaf_plugin_manager.rs | 170 ++++++++++++++++++ .../src/execution/udf_plugin_manager.rs | 164 +++++++++++++++++ datafusion/src/physical_plan/mod.rs | 10 ++ datafusion/src/physical_plan/udaf.rs | 56 ++++++ datafusion/src/physical_plan/udf.rs | 60 ++++++- 17 files changed, 599 insertions(+), 31 deletions(-) create mode 100644 datafusion/src/execution/udaf_plugin_manager.rs create mode 100644 datafusion/src/execution/udf_plugin_manager.rs diff --git a/ballista-examples/Cargo.toml b/ballista-examples/Cargo.toml index 7a19f6346b614..09923d54df82c 100644 --- a/ballista-examples/Cargo.toml +++ b/ballista-examples/Cargo.toml @@ -29,7 +29,7 @@ publish = false rust-version = "1.57" [dependencies] -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client", version = "0.6.0"} prost = "0.9" tonic = "0.6" diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml index 99c07ba2186b3..7e3006bfcf2a1 100644 --- a/ballista/rust/client/Cargo.toml +++ b/ballista/rust/client/Cargo.toml @@ -34,7 +34,7 @@ futures = "0.3" log = "0.4" tokio = "1.0" -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { path = "../../../datafusion" } [features] default = [] diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 8a3991acf7a93..15cb530130dcd 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -42,11 +42,9 @@ tokio = "1.0" tonic = "0.6" uuid = { version = "0.8", features = ["v4"] } chrono = { version = "0.4", default-features = false } - arrow-flight = { version = "7.0.0" } -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } -argo-engine-common = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/argo_engine.git", branch="master", package = "common" } +datafusion = { path = "../../../datafusion" } [dev-dependencies] tempfile = "3" diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index dc621a9b0fe67..55b70a3eddee9 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -22,7 +22,6 @@ use crate::serde::{ from_proto_binary_op, proto_error, protobuf, str_to_byte, vec_to_array, }; use crate::{convert_box_required, convert_required}; -use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf; use
datafusion::datasource::file_format::avro::AvroFormat; use datafusion::datasource::file_format::csv::CsvFormat; @@ -1081,8 +1080,16 @@ impl TryInto for &protobuf::LogicalExprNode { } // argo engine add start ExprType::AggregateUdfExpr(expr) => { - let fun = from_name_to_udaf(expr.fun_name.as_str()) - .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; + let fun = UDAF_PLUGIN_MANAGER + .aggregate_udf_plugins.get(expr.fun_name.as_str()).ok_or_else(|| { + proto_error(format!( + "can not get udaf:{} from UDAF_PLUGIN_MANAGER.aggregate_udf_plugins!", + expr.fun_name.to_string() + )) + })?; + let fun = fun + .get_aggregate_udf_by_name(expr.fun_name.as_str()) + .map_err(|e| BallistaError::DataFusionError(e))?; let fun_arc = Arc::new(fun); let fun_args = &expr.args; let args: Vec = fun_args @@ -1092,8 +1099,18 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateUDF { fun: fun_arc, args }) } ExprType::ScalarUdfProtoExpr(expr) => { - let fun = from_name_to_udf(expr.fun_name.as_str()) - .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; + let fun = UDF_PLUGIN_MANAGER + .scalar_udfs + .get(expr.fun_name.as_str()) + .ok_or_else(|| { + proto_error(format!( + "can not get udf:{} from UDF_PLUGIN_MANAGER.scalar_udfs!", + expr.fun_name.to_string() + )) + })?; + let fun = fun + .get_scalar_udf_by_name(expr.fun_name.as_str()) + .map_err(|e| BallistaError::DataFusionError(e))?; let fun_arc = Arc::new(fun); let fun_args = &expr.args; let args: Vec = fun_args @@ -1304,8 +1321,11 @@ impl TryInto for &protobuf::Field { } use crate::serde::protobuf::ColumnStats; -use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; -use datafusion::physical_plan::udaf::AggregateUDF; +use datafusion::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; +use datafusion::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; +use datafusion::execution::PluginManager; +use datafusion::physical_plan::udaf::{AggregateUDF, UDAFPlugin}; +use datafusion::physical_plan::udf::UDFPlugin; use datafusion::physical_plan::{aggregates, windows}; use datafusion::prelude::{ array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256, diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 37249b9abe33f..6b93db2e9a692 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -17,8 +17,6 @@ //! Serde code to convert from protocol buffers to Rust data structures. 
-use argo_engine_common::udaf::argo_engine_udaf::from_name_to_udaf; -use argo_engine_common::udf::argo_engine_udf::from_name_to_udf; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; @@ -45,6 +43,8 @@ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; +use datafusion::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; use datafusion::logical_plan::{ window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, }; @@ -58,8 +58,10 @@ use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions}; -use datafusion::physical_plan::udaf::create_aggregate_expr as create_aggregate_udf_expr; -use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr}; +use datafusion::physical_plan::udaf::{ + create_aggregate_expr as create_aggregate_udf_expr, UDAFPlugin, +}; +use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr, UDFPlugin}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -319,12 +321,16 @@ impl TryInto> for &protobuf::PhysicalPlanNode { ExprType::AggregateUdfExpr(agg_node) => { let name = agg_node.fun_name.as_str(); let udaf_fun_name = &name[0..name.find('(').unwrap()]; - let fun = from_name_to_udaf(udaf_fun_name).map_err(|e| { + let fun = UDAF_PLUGIN_MANAGER + .aggregate_udf_plugins.get(udaf_fun_name).ok_or_else(|| { proto_error(format!( - "from_proto error: {}", - e + "can not get udaf:{} from UDAF_PLUGIN_MANAGER.aggregate_udf_plugins!", + udaf_fun_name.to_string() )) })?; + let fun = fun + .get_aggregate_udf_by_name(udaf_fun_name) + .map_err(|e| BallistaError::DataFusionError(e))?; let args: Vec> = agg_node.expr .iter() @@ -574,8 +580,19 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { } // argo engine add. 
ExprType::ScalarUdfProtoExpr(e) => { - let fun = from_name_to_udf(&e.fun_name) - .map_err(|e| proto_error(format!("from_proto error: {}", e)))?; + let fun = + UDF_PLUGIN_MANAGER + .scalar_udfs + .get(&e.fun_name) + .ok_or_else(|| { + proto_error(format!( + "can not get udf:{} from UDF_PLUGIN_MANAGER.scalar_udfs!", + &e.fun_name.to_owned() + )) + })?; + let fun = fun + .get_scalar_udf_by_name(&e.fun_name.as_str()) + .map_err(|e| BallistaError::DataFusionError(e))?; let args = e .expr diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 1813a025f5746..d0b4186ec00c8 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -35,7 +35,7 @@ anyhow = "1" async-trait = "0.1.36" ballista-core = { path = "../core"} configure_me = "0.4.0" -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { path = "../../../datafusion" } env_logger = "0.9" futures = "0.3" log = "0.4" diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml index 7f412ab9e906c..3799befffa03a 100644 --- a/ballista/rust/scheduler/Cargo.toml +++ b/ballista/rust/scheduler/Cargo.toml @@ -35,7 +35,7 @@ anyhow = "1" ballista-core = { path = "../core"} clap = "2" configure_me = "0.4.0" -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { path = "../../../datafusion" } env_logger = "0.9" etcd-client = { version = "0.7", optional = true } futures = "0.3" diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index d7fa1c1ae942f..d2e4b3143ec0a 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -32,7 +32,7 @@ simd = ["datafusion/simd"] snmalloc = ["snmalloc-rs"] [dependencies] -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { path = "../datafusion" } ballista = { path = "../ballista/rust/client" } structopt = { version = "0.3", default-features = false } tokio = { version = "^1.0", features = ["macros", "rt", "rt-multi-thread"] } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 1e45bf9bceff8..97215e254fa03 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,6 +30,6 @@ rust-version = "1.57" clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } -datafusion = { git = "http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-datafusion.git", branch="1.0", package = "datafusion" } +datafusion = { path = "../datafusion" } arrow = { version = "7.0.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index bc37c7a0de20d..27e0e31a448d5 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -78,6 +78,8 @@ avro-rs = { version = "0.13", features = ["snappy"], optional = true } num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.14", optional = true } tempfile = "3" +libloading = "0.7.3" +rustc_version = "0.4.0" [dev-dependencies] criterion = "0.3" diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 89ccd7b2b938f..9b1b53dca84d4 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -77,10 +77,13 @@ 
use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use crate::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; +use crate::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; use crate::logical_plan::plan::Explain; use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::physical_plan::planner::DefaultPhysicalPlanner; -use crate::physical_plan::udf::ScalarUDF; +use crate::physical_plan::udaf::UDAFPlugin; +use crate::physical_plan::udf::{ScalarUDF, UDFPlugin}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::PhysicalPlanner; use crate::sql::{ @@ -182,7 +185,7 @@ impl ExecutionContext { let runtime_env = Arc::new(RuntimeEnv::new(config.runtime_config.clone()).unwrap()); - Self { + let mut context = Self { state: Arc::new(Mutex::new(ExecutionContextState { catalog_list, scalar_functions: HashMap::new(), @@ -193,7 +196,42 @@ impl ExecutionContext { object_store_registry: Arc::new(ObjectStoreRegistry::new()), runtime_env, })), - } + }; + + // register udf + UDF_PLUGIN_MANAGER + .plugin_names + .iter() + .for_each(|plugin_name| { + let udf_proxy_option = + UDF_PLUGIN_MANAGER.scalar_udfs.get(plugin_name.as_str()); + if let Some(udf_proxy) = udf_proxy_option { + context.register_udf( + udf_proxy + .get_scalar_udf_by_name(plugin_name.as_str()) + .unwrap(), + ); + } + }); + + // register udaf + UDAF_PLUGIN_MANAGER + .plugin_names + .iter() + .for_each(|plugin_name| { + let udaf_proxy_option = UDAF_PLUGIN_MANAGER + .aggregate_udf_plugins + .get(plugin_name.as_str()); + if let Some(udaf_proxy) = udaf_proxy_option { + context.register_udaf( + udaf_proxy + .get_aggregate_udf_by_name(plugin_name.as_str()) + .unwrap(), + ); + } + }); + + context } /// Creates a dataframe that will execute a SQL query. diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index ebc7c011970b3..6326235a2a528 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ -17,9 +17,48 @@ //! DataFusion query execution +use std::fs::DirEntry; +use std::{fs, io}; + pub mod context; pub mod dataframe_impl; pub(crate) mod disk_manager; pub(crate) mod memory_manager; pub mod options; pub mod runtime_env; +pub mod udaf_plugin_manager; +pub mod udf_plugin_manager; + +/// plugin manager trait +pub trait PluginManager { + /// # Safety + /// find plugin file from `plugin_path` and load it . + unsafe fn load(&mut self, plugin_path: String) -> io::Result<()> { + // find library file from udaf_plugin_path + let library_files = fs::read_dir(plugin_path) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + for entry in library_files { + let entry = entry?; + let file_type = entry.file_type()?; + + if !file_type.is_file() { + continue; + } + + if let Some(path) = entry.path().extension() { + if let Some(suffix) = path.to_str() { + if suffix == "dylib" { + self.load_plugin_from_library(&entry)?; + } + } + } + } + + Ok(()) + } + + /// # Safety + /// load plugin from the library `file` . 
Each plugin type provides its own implementation. + unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()>; +} diff --git a/datafusion/src/execution/udaf_plugin_manager.rs b/datafusion/src/execution/udaf_plugin_manager.rs new file mode 100644 index 0000000000000..4fa10fa877998 --- /dev/null +++ b/datafusion/src/execution/udaf_plugin_manager.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! udaf plugin manager +//! +use crate::physical_plan::udaf::UDAFPluginRegistrar as UDAFPluginRegistrarTrait; +use crate::physical_plan::udaf::{AggregateUDF, UDAFPlugin, UDAFPluginDeclaration}; +use lazy_static::lazy_static; +use libloading::Library; +use std::collections::HashMap; +use std::fs::DirEntry; +use std::io; +use std::sync::Arc; + +use crate::error::Result; +use crate::execution::PluginManager; +use crate::physical_plan::{CORE_VERSION, RUSTC_VERSION}; +lazy_static! { + /// Load all UDAF plugins + pub static ref UDAF_PLUGIN_MANAGER: UDAFPluginManager = unsafe { + let mut plugin = UDAFPluginManager::default(); + plugin.load("plugin/udaf".to_string()).unwrap(); + plugin + }; +} + +/// UDAFPluginManager +#[derive(Default)] +pub struct UDAFPluginManager { + /// Aggregate UDF plugins, stored as udaf_name -> UDAFPluginProxy + pub aggregate_udf_plugins: HashMap>, + + /// Every library needs a plugin_name. + pub plugin_names: Vec, + + /// All libraries loaded from the plugin dir. + pub libraries: Vec>, +} + +impl PluginManager for UDAFPluginManager { + unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()> { + // load the library into memory + let library = Library::new(file.path()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + let library = Arc::new(library); + + // get a pointer to the plugin_declaration symbol.
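+ // SAFETY note: reading this symbol only yields a valid
+ // UDAFPluginDeclaration if the plugin was produced by the
+ // declare_udaf_plugin! macro and built against the same rustc and
+ // datafusion versions; the version check below enforces exactly that.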
+ let dec = library + .get::<*mut UDAFPluginDeclaration>(b"udaf_plugin_declaration\0") + .unwrap() + .read(); + + // version checks to prevent accidental ABI incompatibilities + if dec.rustc_version != RUSTC_VERSION.as_str() || dec.core_version != CORE_VERSION + { + return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch")); + } + + let mut registrar = UDAFPluginRegistrar::new(library.clone()); + (dec.register)(&mut registrar); + + // Check for duplicate plugin_name and UDF names + if let Some(udaf_plugin_proxy) = registrar.udaf_plugin_proxy { + if self.plugin_names.contains(&udaf_plugin_proxy.plugin_name) { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "plugin name: {} already exists", + udaf_plugin_proxy.plugin_name + ), + )); + } + + udaf_plugin_proxy + .aggregate_udf_plugin + .udaf_names() + .unwrap() + .iter() + .try_for_each(|udaf_name| { + if self.aggregate_udf_plugins.contains_key(udaf_name) { + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "udaf name: {} already exists in plugin: {}", + udaf_name, udaf_plugin_proxy.plugin_name + ), + )) + } else { + self.aggregate_udf_plugins.insert( + udaf_name.to_string(), + Arc::new(udaf_plugin_proxy.clone()), + ); + Ok(()) + } + })?; + + self.plugin_names.push(udaf_plugin_proxy.plugin_name); + } + + self.libraries.push(library); + Ok(()) + } +} + +/// A proxy object which wraps a [`UDAFPlugin`] and makes sure it can't outlive +/// the library it came from. +#[derive(Clone)] +pub struct UDAFPluginProxy { + /// One UDAFPluginProxy only have one UDAFPlugin + aggregate_udf_plugin: Arc>, + + /// Library + _lib: Arc, + + /// One Library can only have one plugin + plugin_name: String, +} + +impl UDAFPlugin for UDAFPluginProxy { + fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result { + self.aggregate_udf_plugin + .get_aggregate_udf_by_name(fun_name) + } + + fn udaf_names(&self) -> Result> { + self.aggregate_udf_plugin.udaf_names() + } +} + +/// impl UDAFPluginRegistrarTrait +struct UDAFPluginRegistrar { + udaf_plugin_proxy: Option, + lib: Arc, +} + +impl UDAFPluginRegistrar { + pub fn new(lib: Arc) -> Self { + Self { + udaf_plugin_proxy: None, + lib, + } + } +} + +impl UDAFPluginRegistrarTrait for UDAFPluginRegistrar { + fn register_udaf_plugin(&mut self, plugin_name: &str, plugin: Box) { + let proxy = UDAFPluginProxy { + aggregate_udf_plugin: Arc::new(plugin), + _lib: self.lib.clone(), + plugin_name: plugin_name.to_string(), + }; + + self.udaf_plugin_proxy = Some(proxy); + } +} diff --git a/datafusion/src/execution/udf_plugin_manager.rs b/datafusion/src/execution/udf_plugin_manager.rs new file mode 100644 index 0000000000000..5780ee8e2aed2 --- /dev/null +++ b/datafusion/src/execution/udf_plugin_manager.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! udf plugin manager +//! +use crate::physical_plan::udf::UDFPluginRegistrar as UDFPluginRegistrarTrait; +use crate::physical_plan::udf::{ScalarUDF, UDFPlugin, UDFPluginDeclaration}; +use lazy_static::lazy_static; +use libloading::Library; +use std::collections::HashMap; +use std::fs::DirEntry; +use std::io; +use std::sync::Arc; + +use crate::error::Result; +use crate::execution::PluginManager; +use crate::physical_plan::{CORE_VERSION, RUSTC_VERSION}; +lazy_static! { + /// Load all UDF plugins + pub static ref UDF_PLUGIN_MANAGER: UDFPluginManager = unsafe { + let mut plugin = UDFPluginManager::default(); + plugin.load("plugin/udf".to_string()).unwrap(); + plugin + }; +} + +/// UDFPluginManager +#[derive(Default)] +pub struct UDFPluginManager { + /// Scalar UDF plugins, stored as udf_name -> UDFPluginProxy + pub scalar_udfs: HashMap>, + + /// Every library needs a plugin_name. + pub plugin_names: Vec, + + /// All libraries loaded from the plugin dir. + pub libraries: Vec>, +} + +impl PluginManager for UDFPluginManager { + unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()> { + // load the library into memory + let library = Library::new(file.path()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + let library = Arc::new(library); + + // get a pointer to the plugin_declaration symbol. + let dec = library + .get::<*mut UDFPluginDeclaration>(b"udf_plugin_declaration\0") + .unwrap() + .read(); + + // version checks to prevent accidental ABI incompatibilities + if dec.rustc_version != RUSTC_VERSION.as_str() || dec.core_version != CORE_VERSION + { + return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch")); + } + + let mut registrar = UDFPluginRegistrar::new(library.clone()); + (dec.register)(&mut registrar); + + // Check for duplicate plugin_name and UDF names + if let Some(udf_plugin_proxy) = registrar.udf_plugin_proxy { + if self.plugin_names.contains(&udf_plugin_proxy.plugin_name) { + return Err(io::Error::new( + io::ErrorKind::Other, + format!( + "plugin name: {} already exists", + udf_plugin_proxy.plugin_name + ), + )); + } + + udf_plugin_proxy + .scalar_udf_plugin + .udf_names() + .unwrap() + .iter() + .try_for_each(|udf_name| { + if self.scalar_udfs.contains_key(udf_name) { + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "udf name: {} already exists in plugin: {}", + udf_name, udf_plugin_proxy.plugin_name + ), + )) + } else { + self.scalar_udfs.insert( + udf_name.to_string(), + Arc::new(udf_plugin_proxy.clone()), + ); + Ok(()) + } + })?; + + self.plugin_names.push(udf_plugin_proxy.plugin_name); + } + + self.libraries.push(library); + Ok(()) + } +} + +/// A proxy object which wraps a [`UDFPlugin`] and makes sure it can't outlive +/// the library it came from.
+
+#[derive(Clone)]
+pub struct UDFPluginProxy {
+    scalar_udf_plugin: Arc<Box<dyn UDFPlugin>>,
+    _lib: Arc<Library>,
+    plugin_name: String,
+}
+
+impl UDFPlugin for UDFPluginProxy {
+    fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result<ScalarUDF> {
+        self.scalar_udf_plugin.get_scalar_udf_by_name(fun_name)
+    }
+
+    fn udf_names(&self) -> Result<Vec<String>> {
+        self.scalar_udf_plugin.udf_names()
+    }
+}
+
+struct UDFPluginRegistrar {
+    udf_plugin_proxy: Option<UDFPluginProxy>,
+    lib: Arc<Library>,
+}
+
+impl UDFPluginRegistrar {
+    pub fn new(lib: Arc<Library>) -> Self {
+        Self {
+            udf_plugin_proxy: None,
+            lib,
+        }
+    }
+}
+
+impl UDFPluginRegistrarTrait for UDFPluginRegistrar {
+    fn register_udf_plugin(&mut self, plugin_name: &str, plugin: Box<dyn UDFPlugin>) {
+        let proxy = UDFPluginProxy {
+            scalar_udf_plugin: Arc::new(plugin),
+            _lib: self.lib.clone(),
+            plugin_name: plugin_name.to_string(),
+        };
+
+        self.udf_plugin_proxy = Some(proxy);
+    }
+}
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 836d3994343f2..9340860e75524 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -37,6 +37,8 @@ use arrow::{array::ArrayRef, datatypes::Field};
 use async_trait::async_trait;
 pub use display::DisplayFormatType;
 use futures::stream::Stream;
+use lazy_static::lazy_static;
+use rustc_version::version;
 use std::fmt;
 use std::fmt::{Debug, Display};
 use std::ops::Range;
@@ -629,6 +631,14 @@ pub trait Accumulator: Send + Sync + Debug {
     fn evaluate(&self) -> Result<ScalarValue>;
 }
 
+/// CARGO_PKG_VERSION
+pub static CORE_VERSION: &str = env!("CARGO_PKG_VERSION");
+lazy_static! {
+    /// set rustc version to static
+    pub static ref RUSTC_VERSION: String = {
+        version().unwrap().to_string()
+    };
+}
 pub mod aggregates;
 pub mod analyze;
 pub mod array_expressions;
diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs
index 974b4a9df764f..1803c01363627 100644
--- a/datafusion/src/physical_plan/udaf.rs
+++ b/datafusion/src/physical_plan/udaf.rs
@@ -39,6 +39,62 @@ use super::{
 };
 use std::sync::Arc;
 
+/// Defines a UDAF plugin. Whoever provides a UDAF must implement this trait.
+pub trait UDAFPlugin: Send + Sync + 'static {
+    /// get an aggregate UDF by name
+    fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result<AggregateUDF>;
+
+    /// return all UDAF names
+    fn udaf_names(&self) -> Result<Vec<String>>;
+}
+
+/// Every plugin needs a UDAFPluginDeclaration
+#[derive(Copy, Clone)]
+pub struct UDAFPluginDeclaration {
+    /// rustc version of the plugin. It must match the plugin manager's rustc_version.
+    pub rustc_version: &'static str,
+
+    /// core version of the plugin. It must match the plugin manager's core_version.
+    pub core_version: &'static str,
+
+    /// `register` is a function that receives a UDAFPluginRegistrar. It is called when the plugin is loaded.
+    pub register: unsafe extern "C" fn(&mut dyn UDAFPluginRegistrar),
+}
+
+/// UDAF plugin registrar. Defines the functions every UDAF plugin must implement.
+pub trait UDAFPluginRegistrar {
+    /// Every UDAF plugin must implement this function
+    fn register_udaf_plugin(&mut self, plugin_name: &str, function: Box<dyn UDAFPlugin>);
+}
+
+/// Declare an aggregate UDF plugin's name, type and its constructor.
+///
+/// # Notes
+///
+/// This works by automatically generating an `extern "C"` function with a
+/// pre-defined signature and symbol name, and then generating a UDAFPluginDeclaration.
+/// Therefore you will only be able to declare one plugin per library.
+#[macro_export]
+macro_rules! declare_udaf_plugin {
+    ($plugin_name:expr, $plugin_type:ty, $constructor:path) => {
+        #[no_mangle]
+        pub extern "C" fn register_plugin(registrar: &mut dyn UDAFPluginRegistrar) {
+            // make sure the constructor is the correct type.
+            let constructor: fn() -> $plugin_type = $constructor;
+            let object = constructor();
+            registrar.register_udaf_plugin($plugin_name, Box::new(object));
+        }
+
+        #[no_mangle]
+        pub static udaf_plugin_declaration: $crate::UDAFPluginDeclaration =
+            $crate::UDAFPluginDeclaration {
+                rustc_version: $crate::RUSTC_VERSION,
+                core_version: $crate::CORE_VERSION,
+                register: register_plugin,
+            };
+    };
+}
+
 /// Logical representation of a user-defined aggregate function (UDAF)
 /// A UDAF is different from a UDF in that it is stateful across batches.
 #[derive(Clone)]
diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs
index 4d8fdd8b57b8f..1774e904816ec 100644
--- a/datafusion/src/physical_plan/udf.rs
+++ b/datafusion/src/physical_plan/udf.rs
@@ -27,15 +27,69 @@ use crate::error::Result;
 use crate::{logical_plan::Expr, physical_plan::PhysicalExpr};
 
 use super::{
-    functions::{
-        ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature,
-    },
+    functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature},
     type_coercion::coerce,
 };
 use crate::physical_plan::ColumnarValue;
 use arrow::record_batch::RecordBatch;
 use std::sync::Arc;
 
+/// Defines a UDF plugin. Whoever provides a UDF must implement this trait.
+pub trait UDFPlugin: Send + Sync + 'static {
+    /// get a ScalarUDF by name
+    fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result<ScalarUDF>;
+
+    /// return all UDF names in the plugin
+    fn udf_names(&self) -> Result<Vec<String>>;
+}
+
+/// Every plugin needs a UDFPluginDeclaration
+#[derive(Copy, Clone)]
+pub struct UDFPluginDeclaration {
+    /// rustc version of the plugin. It must match the plugin manager's rustc_version.
+    pub rustc_version: &'static str,
+
+    /// core version of the plugin. It must match the plugin manager's core_version.
+    pub core_version: &'static str,
+
+    /// `register` is a function that receives a UDFPluginRegistrar. It is called when the plugin is loaded.
+    pub register: unsafe extern "C" fn(&mut dyn UDFPluginRegistrar),
+}
+
+/// UDF plugin registrar. Defines the functions every UDF plugin must implement.
+pub trait UDFPluginRegistrar {
+    /// Every UDF plugin must implement this function
+    fn register_udf_plugin(&mut self, plugin_name: &str, function: Box<dyn UDFPlugin>);
+}
+
+/// Declare a plugin's name, type and its constructor.
+///
+/// # Notes
+///
+/// This works by automatically generating an `extern "C"` function with a
+/// pre-defined signature and symbol name, and then generating a UDFPluginDeclaration.
+/// Therefore you will only be able to declare one plugin per library.
+#[macro_export]
+macro_rules! declare_udf_plugin {
+    ($plugin_name:expr, $plugin_type:ty, $constructor:path) => {
+        #[no_mangle]
+        pub extern "C" fn register_plugin(registrar: &mut dyn UDFPluginRegistrar) {
+            // make sure the constructor is the correct type.
+            let constructor: fn() -> $plugin_type = $constructor;
+            let object = constructor();
+            registrar.register_udf_plugin($plugin_name, Box::new(object));
+        }
+
+        #[no_mangle]
+        pub static udf_plugin_declaration: $crate::UDFPluginDeclaration =
+            $crate::UDFPluginDeclaration {
+                rustc_version: $crate::RUSTC_VERSION,
+                core_version: $crate::CORE_VERSION,
+                register: register_plugin,
+            };
+    };
+}
+
 /// Logical representation of a UDF.
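+///
+/// # Plugin example
+///
+/// A `ScalarUDF` can also come from a dynamically loaded plugin. The sketch
+/// below (hypothetical plugin type and UDF name, not part of this patch) shows
+/// how a plugin crate would implement [`UDFPlugin`] and export itself with
+/// `declare_udf_plugin!`; `create_udf` and `make_scalar_function` are
+/// DataFusion's stock helpers:
+///
+/// ```ignore
+/// use std::sync::Arc;
+/// use datafusion::arrow::array::{ArrayRef, Int64Array};
+/// use datafusion::arrow::datatypes::DataType;
+/// use datafusion::declare_udf_plugin;
+/// use datafusion::error::{DataFusionError, Result};
+/// use datafusion::physical_plan::functions::{make_scalar_function, Volatility};
+/// use datafusion::physical_plan::udf::{ScalarUDF, UDFPlugin, UDFPluginRegistrar};
+/// use datafusion::prelude::create_udf;
+///
+/// #[derive(Default)]
+/// struct MyUdfPlugin;
+///
+/// impl UDFPlugin for MyUdfPlugin {
+///     fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result<ScalarUDF> {
+///         match fun_name {
+///             "my_add_one" => {
+///                 // element-wise i64 + 1
+///                 let add_one = make_scalar_function(|args: &[ArrayRef]| {
+///                     let input = args[0]
+///                         .as_any()
+///                         .downcast_ref::<Int64Array>()
+///                         .expect("i64 argument");
+///                     let out: Int64Array =
+///                         input.iter().map(|v| v.map(|x| x + 1)).collect();
+///                     Ok(Arc::new(out) as ArrayRef)
+///                 });
+///                 Ok(create_udf(
+///                     "my_add_one",
+///                     vec![DataType::Int64],
+///                     Arc::new(DataType::Int64),
+///                     Volatility::Immutable,
+///                     add_one,
+///                 ))
+///             }
+///             other => Err(DataFusionError::Execution(format!(
+///                 "unknown udf {}",
+///                 other
+///             ))),
+///         }
+///     }
+///
+///     fn udf_names(&self) -> Result<Vec<String>> {
+///         Ok(vec!["my_add_one".to_string()])
+///     }
+/// }
+///
+/// // Exactly one plugin per shared library; this exports `udf_plugin_declaration`.
+/// declare_udf_plugin!("my_udf_plugin", MyUdfPlugin, MyUdfPlugin::default);
+/// ```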
#[derive(Clone)] pub struct ScalarUDF { From 189169144aded07eb8369950984010a9aba25a1c Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 15 Feb 2022 12:07:31 +0800 Subject: [PATCH 26/38] =?UTF-8?q?=E4=BC=98=E5=8C=96UDF=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E5=8C=96=E5=AE=8F=E4=BB=A3=E7=A0=81=EF=BC=8C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E4=B8=80=E4=B8=AALibaray=E4=B8=AD=E5=90=8C=E6=97=B6=E5=AE=9A?= =?UTF-8?q?=E4=B9=89UDF=E5=92=8CUDAF=E7=9A=84=E6=8F=92=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/src/physical_plan/udaf.rs | 4 ++-- datafusion/src/physical_plan/udf.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 1803c01363627..39ed5fd510c1b 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -78,7 +78,7 @@ pub trait UDAFPluginRegistrar { macro_rules! declare_udaf_plugin { ($plugin_name:expr, $plugin_type:ty, $constructor:path) => { #[no_mangle] - pub extern "C" fn register_plugin(registrar: &mut dyn UDAFPluginRegistrar) { + pub extern "C" fn register_udaf_plugin(registrar: &mut dyn UDAFPluginRegistrar) { // make sure the constructor is the correct type. let constructor: fn() -> $plugin_type = $constructor; let object = constructor(); @@ -90,7 +90,7 @@ macro_rules! declare_udaf_plugin { $crate::UDAFPluginDeclaration { rustc_version: $crate::RUSTC_VERSION, core_version: $crate::CORE_VERSION, - register: register_plugin, + register: register_udaf_plugin, }; }; } diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 1774e904816ec..a75bc326ac8a3 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -73,7 +73,7 @@ pub trait UDFPluginRegistrar { macro_rules! declare_udf_plugin { ($plugin_name:expr, $plugin_type:ty, $constructor:path) => { #[no_mangle] - pub extern "C" fn register_plugin(registrar: &mut dyn UDFPluginRegistrar) { + pub extern "C" fn register_udf_plugin(registrar: &mut dyn UDFPluginRegistrar) { // make sure the constructor is the correct type. let constructor: fn() -> $plugin_type = $constructor; let object = constructor(); @@ -85,7 +85,7 @@ macro_rules! 
declare_udf_plugin { $crate::UDFPluginDeclaration { rustc_version: $crate::RUSTC_VERSION, core_version: $crate::CORE_VERSION, - register: register_plugin, + register: register_udf_plugin, }; }; } From 7401f3078855bd1f88fdee7a79c588e060b89061 Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 15 Feb 2022 13:35:29 +0800 Subject: [PATCH 27/38] =?UTF-8?q?=E4=BC=98=E5=8C=96UDF=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E5=8C=96=E4=BB=A3=E7=A0=81=EF=BC=8C=E6=B7=BB=E5=8A=A0build.rs?= =?UTF-8?q?=E5=9C=A8build=E6=97=B6=E8=87=AA=E5=8A=A8=E6=B7=BB=E5=8A=A0RUST?= =?UTF-8?q?C=5FVESION=E5=88=B0=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/Cargo.toml | 3 +++ datafusion/build.rs | 4 ++++ datafusion/src/execution/udaf_plugin_manager.rs | 3 +-- datafusion/src/execution/udf_plugin_manager.rs | 3 +-- datafusion/src/physical_plan/mod.rs | 10 ++-------- 5 files changed, 11 insertions(+), 12 deletions(-) create mode 100644 datafusion/build.rs diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 27e0e31a448d5..3728dbb4803f3 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -85,6 +85,9 @@ rustc_version = "0.4.0" criterion = "0.3" doc-comment = "0.3" +[build-dependencies] +rustc_version = "0.4.0" + [[bench]] name = "aggregate_query_sql" harness = false diff --git a/datafusion/build.rs b/datafusion/build.rs new file mode 100644 index 0000000000000..a38022ffdaa70 --- /dev/null +++ b/datafusion/build.rs @@ -0,0 +1,4 @@ +fn main() { + let version = rustc_version::version().unwrap(); + println!("cargo:rustc-env=RUSTC_VERSION={}", version); +} diff --git a/datafusion/src/execution/udaf_plugin_manager.rs b/datafusion/src/execution/udaf_plugin_manager.rs index 4fa10fa877998..3cff5b784f758 100644 --- a/datafusion/src/execution/udaf_plugin_manager.rs +++ b/datafusion/src/execution/udaf_plugin_manager.rs @@ -66,8 +66,7 @@ impl PluginManager for UDAFPluginManager { .read(); // version checks to prevent accidental ABI incompatibilities - if dec.rustc_version != RUSTC_VERSION.as_str() || dec.core_version != CORE_VERSION - { + if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch")); } diff --git a/datafusion/src/execution/udf_plugin_manager.rs b/datafusion/src/execution/udf_plugin_manager.rs index 5780ee8e2aed2..496639e4f6540 100644 --- a/datafusion/src/execution/udf_plugin_manager.rs +++ b/datafusion/src/execution/udf_plugin_manager.rs @@ -66,8 +66,7 @@ impl PluginManager for UDFPluginManager { .read(); // version checks to prevent accidental ABI incompatibilities - if dec.rustc_version != RUSTC_VERSION.as_str() || dec.core_version != CORE_VERSION - { + if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch")); } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 9340860e75524..42f0cf22d0c14 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -37,8 +37,6 @@ use arrow::{array::ArrayRef, datatypes::Field}; use async_trait::async_trait; pub use display::DisplayFormatType; use futures::stream::Stream; -use lazy_static::lazy_static; -use rustc_version::version; use std::fmt; use std::fmt::{Debug, Display}; use std::ops::Range; @@ -633,12 +631,8 @@ pub trait Accumulator: Send + Sync + Debug { /// CARGO_PKG_VERSION pub static CORE_VERSION: &str = 
env!("CARGO_PKG_VERSION"); -lazy_static! { - /// set rustc version to static - pub static ref RUSTC_VERSION: String = { - version().unwrap().to_string() - }; -} +/// RUSTC_VERSION +pub static RUSTC_VERSION: &str = env!("RUSTC_VERSION"); pub mod aggregates; pub mod analyze; pub mod array_expressions; From cfa2319e256120bc27efc3e2edd4f097d04c6756 Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 15 Feb 2022 14:09:50 +0800 Subject: [PATCH 28/38] =?UTF-8?q?=E4=BC=98=E5=8C=96UDF=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E5=8C=96=E4=BB=A3=E7=A0=81=EF=BC=8C=E5=8F=AA=E6=9C=89=E5=8A=A8?= =?UTF-8?q?=E6=80=81=E5=BA=93=E4=B8=AD=E6=9C=89=E6=8F=92=E4=BB=B6=EF=BC=8C?= =?UTF-8?q?=E6=89=8D=E4=BF=9D=E5=AD=98=E8=AF=A5Library?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/src/execution/udaf_plugin_manager.rs | 2 +- datafusion/src/execution/udf_plugin_manager.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/src/execution/udaf_plugin_manager.rs b/datafusion/src/execution/udaf_plugin_manager.rs index 3cff5b784f758..9f99496cc5ea6 100644 --- a/datafusion/src/execution/udaf_plugin_manager.rs +++ b/datafusion/src/execution/udaf_plugin_manager.rs @@ -109,9 +109,9 @@ impl PluginManager for UDAFPluginManager { })?; self.plugin_names.push(udaf_plugin_proxy.plugin_name); + self.libraries.push(library); } - self.libraries.push(library); Ok(()) } } diff --git a/datafusion/src/execution/udf_plugin_manager.rs b/datafusion/src/execution/udf_plugin_manager.rs index 496639e4f6540..80fd142ce0685 100644 --- a/datafusion/src/execution/udf_plugin_manager.rs +++ b/datafusion/src/execution/udf_plugin_manager.rs @@ -109,9 +109,9 @@ impl PluginManager for UDFPluginManager { })?; self.plugin_names.push(udf_plugin_proxy.plugin_name); + self.libraries.push(library); } - self.libraries.push(library); Ok(()) } } From d0a3bf76f5f2eedfe101a810112da3a57a23052d Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 15 Feb 2022 16:04:59 +0800 Subject: [PATCH 29/38] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E5=AE=8F=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/src/physical_plan/udaf.rs | 13 ++++++++----- datafusion/src/physical_plan/udf.rs | 13 ++++++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 39ed5fd510c1b..66f8b2c96b151 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -78,7 +78,9 @@ pub trait UDAFPluginRegistrar { macro_rules! declare_udaf_plugin { ($plugin_name:expr, $plugin_type:ty, $constructor:path) => { #[no_mangle] - pub extern "C" fn register_udaf_plugin(registrar: &mut dyn UDAFPluginRegistrar) { + pub extern "C" fn register_udaf_plugin( + registrar: &mut dyn $crate::physical_plan::udaf::UDAFPluginRegistrar, + ) { // make sure the constructor is the correct type. let constructor: fn() -> $plugin_type = $constructor; let object = constructor(); @@ -86,10 +88,11 @@ macro_rules! 
declare_udaf_plugin { } #[no_mangle] - pub static udaf_plugin_declaration: $crate::UDAFPluginDeclaration = - $crate::UDAFPluginDeclaration { - rustc_version: $crate::RUSTC_VERSION, - core_version: $crate::CORE_VERSION, + pub static udaf_plugin_declaration: + $crate::physical_plan::udaf::UDAFPluginDeclaration = + $crate::physical_plan::udaf::UDAFPluginDeclaration { + rustc_version: $crate::physical_plan::RUSTC_VERSION, + core_version: $crate::physical_plan::CORE_VERSION, register: register_udaf_plugin, }; }; diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index a75bc326ac8a3..69580a6cd83b3 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -73,7 +73,9 @@ pub trait UDFPluginRegistrar { macro_rules! declare_udf_plugin { ($plugin_name:expr, $plugin_type:ty, $constructor:path) => { #[no_mangle] - pub extern "C" fn register_udf_plugin(registrar: &mut dyn UDFPluginRegistrar) { + pub extern "C" fn register_udf_plugin( + registrar: &mut dyn $crate::physical_plan::udf::UDFPluginRegistrar, + ) { // make sure the constructor is the correct type. let constructor: fn() -> $plugin_type = $constructor; let object = constructor(); @@ -81,10 +83,11 @@ macro_rules! declare_udf_plugin { } #[no_mangle] - pub static udf_plugin_declaration: $crate::UDFPluginDeclaration = - $crate::UDFPluginDeclaration { - rustc_version: $crate::RUSTC_VERSION, - core_version: $crate::CORE_VERSION, + pub static udf_plugin_declaration: + $crate::physical_plan::udf::UDFPluginDeclaration = + $crate::physical_plan::udf::UDFPluginDeclaration { + rustc_version: $crate::physical_plan::RUSTC_VERSION, + core_version: $crate::physical_plan::CORE_VERSION, register: register_udf_plugin, }; }; From 601181aad17d6d4418e96612f2597c83764e96a9 Mon Sep 17 00:00:00 2001 From: gaojun Date: Wed, 16 Feb 2022 10:27:56 +0800 Subject: [PATCH 30/38] =?UTF-8?q?=E9=87=8D=E6=9E=84plugin=5Fmanager?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=8E=B7=E5=8F=96=E7=B3=BB=E7=BB=9F=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E8=B7=AF=E5=BE=84=E7=9A=84=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/src/serde/logical_plan/from_proto.rs | 2 +- datafusion/src/execution/mod.rs | 38 +--------- datafusion/src/execution/plugin_manager.rs | 76 +++++++++++++++++++ .../src/execution/udaf_plugin_manager.rs | 5 +- .../src/execution/udf_plugin_manager.rs | 5 +- 5 files changed, 84 insertions(+), 42 deletions(-) create mode 100644 datafusion/src/execution/plugin_manager.rs diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 55b70a3eddee9..55d8305a48cf1 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1321,9 +1321,9 @@ impl TryInto for &protobuf::Field { } use crate::serde::protobuf::ColumnStats; +use datafusion::execution::plugin_manager::PluginManager; use datafusion::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; use datafusion::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; -use datafusion::execution::PluginManager; use datafusion::physical_plan::udaf::{AggregateUDF, UDAFPlugin}; use datafusion::physical_plan::udf::UDFPlugin; use datafusion::physical_plan::{aggregates, windows}; diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index 6326235a2a528..09d5d6dc570ec 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ 
-17,48 +17,12 @@ //! DataFusion query execution -use std::fs::DirEntry; -use std::{fs, io}; - pub mod context; pub mod dataframe_impl; pub(crate) mod disk_manager; pub(crate) mod memory_manager; pub mod options; +pub mod plugin_manager; pub mod runtime_env; pub mod udaf_plugin_manager; pub mod udf_plugin_manager; - -/// plugin manager trait -pub trait PluginManager { - /// # Safety - /// find plugin file from `plugin_path` and load it . - unsafe fn load(&mut self, plugin_path: String) -> io::Result<()> { - // find library file from udaf_plugin_path - let library_files = fs::read_dir(plugin_path) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - for entry in library_files { - let entry = entry?; - let file_type = entry.file_type()?; - - if !file_type.is_file() { - continue; - } - - if let Some(path) = entry.path().extension() { - if let Some(suffix) = path.to_str() { - if suffix == "dylib" { - self.load_plugin_from_library(&entry)?; - } - } - } - } - - Ok(()) - } - - /// # Safety - /// load plugin from the library `file` . Every different plugins should have different implementations - unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()>; -} diff --git a/datafusion/src/execution/plugin_manager.rs b/datafusion/src/execution/plugin_manager.rs new file mode 100644 index 0000000000000..d42bb8efdc506 --- /dev/null +++ b/datafusion/src/execution/plugin_manager.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! plugin manager + +use std::fs::DirEntry; +use std::{env, fs, io}; + +/// plugin manager trait +pub trait PluginManager { + /// # Safety + /// find plugin file from `plugin_path` and load it . + unsafe fn load(&mut self, plugin_path: String) -> io::Result<()> { + // find library file from udaf_plugin_path + let library_files = fs::read_dir(plugin_path) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + + for entry in library_files { + let entry = entry?; + let file_type = entry.file_type()?; + + if !file_type.is_file() { + continue; + } + + if let Some(path) = entry.path().extension() { + if let Some(suffix) = path.to_str() { + if suffix == "dylib" { + self.load_plugin_from_library(&entry)?; + } + } + } + } + + Ok(()) + } + + /// # Safety + /// load plugin from the library `file` . 
Every different plugins should have different implementations + unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()>; +} + +/// get the plugin dir +pub fn plugin_dir() -> String { + let current_exe_dir = match env::current_exe() { + Ok(exe_path) => exe_path.display().to_string(), + Err(_e) => "".to_string(), + }; + + // If current_exe_dir contain `deps` the root dir is the parent dir + // eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/deps/plugins_app-067452b3ff2af70e + // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug + // else eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/plugins_app + // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/ + if current_exe_dir.contains("/deps/") { + let i = current_exe_dir.find("/deps/").unwrap(); + String::from(¤t_exe_dir.as_str()[..i]) + } else { + let i = current_exe_dir.rfind('/').unwrap(); + String::from(¤t_exe_dir.as_str()[..i]) + } +} diff --git a/datafusion/src/execution/udaf_plugin_manager.rs b/datafusion/src/execution/udaf_plugin_manager.rs index 9f99496cc5ea6..a56adfd3aacdf 100644 --- a/datafusion/src/execution/udaf_plugin_manager.rs +++ b/datafusion/src/execution/udaf_plugin_manager.rs @@ -27,13 +27,14 @@ use std::io; use std::sync::Arc; use crate::error::Result; -use crate::execution::PluginManager; +use crate::execution::plugin_manager::{plugin_dir, PluginManager}; use crate::physical_plan::{CORE_VERSION, RUSTC_VERSION}; lazy_static! { /// load all udaf plugin pub static ref UDAF_PLUGIN_MANAGER: UDAFPluginManager = unsafe { let mut plugin = UDAFPluginManager::default(); - plugin.load("plugin/udaf".to_string()).unwrap(); + let plugin_path = plugin_dir(); + plugin.load(plugin_path).unwrap(); plugin }; } diff --git a/datafusion/src/execution/udf_plugin_manager.rs b/datafusion/src/execution/udf_plugin_manager.rs index 80fd142ce0685..b316af62b0c40 100644 --- a/datafusion/src/execution/udf_plugin_manager.rs +++ b/datafusion/src/execution/udf_plugin_manager.rs @@ -17,6 +17,7 @@ //! udf plugin manager //! +use crate::execution::plugin_manager::{plugin_dir, PluginManager}; use crate::physical_plan::udf::UDFPluginRegistrar as UDFPluginRegistrarTrait; use crate::physical_plan::udf::{ScalarUDF, UDFPlugin, UDFPluginDeclaration}; use lazy_static::lazy_static; @@ -27,13 +28,13 @@ use std::io; use std::sync::Arc; use crate::error::Result; -use crate::execution::PluginManager; use crate::physical_plan::{CORE_VERSION, RUSTC_VERSION}; lazy_static! 
{ /// load all udf plugin pub static ref UDF_PLUGIN_MANAGER: UDFPluginManager = unsafe { let mut plugin = UDFPluginManager::default(); - plugin.load("plugin/udf".to_string()).unwrap(); + let plugin_path = plugin_dir(); + plugin.load(plugin_path).unwrap(); plugin }; } From 557ac9638eea5896a9e5f9b4fecd49f50510981f Mon Sep 17 00:00:00 2001 From: gaojun Date: Wed, 16 Feb 2022 10:52:05 +0800 Subject: [PATCH 31/38] =?UTF-8?q?=E9=87=8D=E6=9E=84plugin=5Fmanager?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=8E=B7=E5=8F=96=E7=B3=BB=E7=BB=9F=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E8=B7=AF=E5=BE=84=E7=9A=84=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/src/execution/plugin_manager.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/datafusion/src/execution/plugin_manager.rs b/datafusion/src/execution/plugin_manager.rs index d42bb8efdc506..4939071e5d5e7 100644 --- a/datafusion/src/execution/plugin_manager.rs +++ b/datafusion/src/execution/plugin_manager.rs @@ -17,6 +17,7 @@ //! plugin manager +use log::info; use std::fs::DirEntry; use std::{env, fs, io}; @@ -26,6 +27,8 @@ pub trait PluginManager { /// find plugin file from `plugin_path` and load it . unsafe fn load(&mut self, plugin_path: String) -> io::Result<()> { // find library file from udaf_plugin_path + info!("load plugin from dir:{}", plugin_path); + println!("load plugin from dir:{}", plugin_path); let library_files = fs::read_dir(plugin_path) .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; @@ -40,6 +43,11 @@ pub trait PluginManager { if let Some(path) = entry.path().extension() { if let Some(suffix) = path.to_str() { if suffix == "dylib" { + info!("load plugin from library file:{}", path.to_str().unwrap()); + println!( + "load plugin from library file:{}", + path.to_str().unwrap() + ); self.load_plugin_from_library(&entry)?; } } From 657b90f7ed63c9c472c9870d29d721377216ee34 Mon Sep 17 00:00:00 2001 From: gaojun Date: Wed, 16 Feb 2022 11:16:47 +0800 Subject: [PATCH 32/38] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddatafusion=20ExecutionC?= =?UTF-8?q?ontext=E6=B2=A1=E6=9C=89=E8=87=AA=E5=8A=A8=E6=B3=A8=E5=86=8Cudf?= =?UTF-8?q?=E5=92=8Cudaf=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/src/serde/logical_plan/from_proto.rs | 2 +- .../src/serde/physical_plan/from_proto.rs | 2 +- datafusion/src/execution/context.rs | 42 +++++++------------ .../src/execution/udaf_plugin_manager.rs | 8 ++-- .../src/execution/udf_plugin_manager.rs | 2 +- 5 files changed, 23 insertions(+), 33 deletions(-) diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 55d8305a48cf1..51341544efdd5 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1081,7 +1081,7 @@ impl TryInto for &protobuf::LogicalExprNode { // argo engine add start ExprType::AggregateUdfExpr(expr) => { let fun = UDAF_PLUGIN_MANAGER - .aggregate_udf_plugins.get(expr.fun_name.as_str()).ok_or_else(|| { + .aggregate_udfs.get(expr.fun_name.as_str()).ok_or_else(|| { proto_error(format!( "can not get udaf:{} from UDAF_PLUGIN_MANAGER.aggregate_udf_plugins!", expr.fun_name.to_string() diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 6b93db2e9a692..e006631f9c948 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ 
b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -322,7 +322,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { let name = agg_node.fun_name.as_str(); let udaf_fun_name = &name[0..name.find('(').unwrap()]; let fun = UDAF_PLUGIN_MANAGER - .aggregate_udf_plugins.get(udaf_fun_name).ok_or_else(|| { + .aggregate_udfs.get(udaf_fun_name).ok_or_else(|| { proto_error(format!( "can not get udaf:{} from UDAF_PLUGIN_MANAGER.aggregate_udf_plugins!", udaf_fun_name.to_string() diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 9b1b53dca84d4..24c919b24bf09 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -200,36 +200,26 @@ impl ExecutionContext { // register udf UDF_PLUGIN_MANAGER - .plugin_names + .scalar_udfs .iter() - .for_each(|plugin_name| { - let udf_proxy_option = - UDF_PLUGIN_MANAGER.scalar_udfs.get(plugin_name.as_str()); - if let Some(udf_proxy) = udf_proxy_option { - context.register_udf( - udf_proxy - .get_scalar_udf_by_name(plugin_name.as_str()) - .unwrap(), - ); - } + .for_each(|(udf_name, plugin_proxy)| { + context.register_udf( + plugin_proxy + .get_scalar_udf_by_name(udf_name.as_str()) + .unwrap(), + ) }); // register udaf - UDAF_PLUGIN_MANAGER - .plugin_names - .iter() - .for_each(|plugin_name| { - let udaf_proxy_option = UDAF_PLUGIN_MANAGER - .aggregate_udf_plugins - .get(plugin_name.as_str()); - if let Some(udaf_proxy) = udaf_proxy_option { - context.register_udaf( - udaf_proxy - .get_aggregate_udf_by_name(plugin_name.as_str()) - .unwrap(), - ); - } - }); + UDAF_PLUGIN_MANAGER.aggregate_udfs.iter().for_each( + |(udaf_name, plugin_proxy)| { + context.register_udaf( + plugin_proxy + .get_aggregate_udf_by_name(udaf_name.as_str()) + .unwrap(), + ); + }, + ); context } diff --git a/datafusion/src/execution/udaf_plugin_manager.rs b/datafusion/src/execution/udaf_plugin_manager.rs index a56adfd3aacdf..3a5054d6a20d1 100644 --- a/datafusion/src/execution/udaf_plugin_manager.rs +++ b/datafusion/src/execution/udaf_plugin_manager.rs @@ -42,8 +42,8 @@ lazy_static! { /// UDAFPluginManager #[derive(Default)] pub struct UDAFPluginManager { - /// aggregate udf plugins save as udaf_name:UDAFPluginProxy - pub aggregate_udf_plugins: HashMap>, + /// aggregate udfs save as udaf_name:UDAFPluginProxy + pub aggregate_udfs: HashMap>, /// Every Library need a plugin_name . pub plugin_names: Vec, @@ -92,7 +92,7 @@ impl PluginManager for UDAFPluginManager { .unwrap() .iter() .try_for_each(|udaf_name| { - if self.aggregate_udf_plugins.contains_key(udaf_name) { + if self.aggregate_udfs.contains_key(udaf_name) { Err(io::Error::new( io::ErrorKind::Other, format!( @@ -101,7 +101,7 @@ impl PluginManager for UDAFPluginManager { ), )) } else { - self.aggregate_udf_plugins.insert( + self.aggregate_udfs.insert( udaf_name.to_string(), Arc::new(udaf_plugin_proxy.clone()), ); diff --git a/datafusion/src/execution/udf_plugin_manager.rs b/datafusion/src/execution/udf_plugin_manager.rs index b316af62b0c40..6c7c08dcbaa6d 100644 --- a/datafusion/src/execution/udf_plugin_manager.rs +++ b/datafusion/src/execution/udf_plugin_manager.rs @@ -42,7 +42,7 @@ lazy_static! { /// UDFPluginManager #[derive(Default)] pub struct UDFPluginManager { - /// scalar udf plugins save as udaf_name:UDFPluginProxy + /// scalar udfs save as udaf_name:UDFPluginProxy pub scalar_udfs: HashMap>, /// Every Library need a plugin_name . 
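
With the change above, an ExecutionContext picks up plugin UDFs and UDAFs
without any explicit registration call. A minimal end-to-end sketch follows;
the `my_add_one` UDF, the plugin library, and `tests/example.csv` with an
integer column `a` are assumptions, not part of this patch, and the library
must be built with the same rustc and core versions and placed in the
directory returned by `plugin_dir()`:

    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        // ExecutionContext::new() initializes UDF_PLUGIN_MANAGER and
        // UDAF_PLUGIN_MANAGER, which scan the plugin dir once per process
        // and hand every exported function to register_udf/register_udaf.
        let mut ctx = ExecutionContext::new();
        ctx.register_csv("t", "tests/example.csv", CsvReadOptions::new())
            .await?;
        let df = ctx.sql("SELECT my_add_one(a) FROM t").await?;
        df.show().await?;
        Ok(())
    }
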
From e7c02f96d7e7102586b2de8ee648a2705bad005a Mon Sep 17 00:00:00 2001 From: gaojun Date: Wed, 16 Feb 2022 13:19:08 +0800 Subject: [PATCH 33/38] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddatafusion=20ExecutionC?= =?UTF-8?q?ontext=E6=B2=A1=E6=9C=89=E8=87=AA=E5=8A=A8=E6=B3=A8=E5=86=8Cudf?= =?UTF-8?q?=E5=92=8Cudaf=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/src/execution/plugin_manager.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/execution/plugin_manager.rs b/datafusion/src/execution/plugin_manager.rs index 4939071e5d5e7..286172317ae3b 100644 --- a/datafusion/src/execution/plugin_manager.rs +++ b/datafusion/src/execution/plugin_manager.rs @@ -42,7 +42,7 @@ pub trait PluginManager { if let Some(path) = entry.path().extension() { if let Some(suffix) = path.to_str() { - if suffix == "dylib" { + if suffix == "dylib" || suffix == "os" || suffix == "dll" { info!("load plugin from library file:{}", path.to_str().unwrap()); println!( "load plugin from library file:{}", From 8a4b8cfb4396f1cf6023f80bd9d219227409ac22 Mon Sep 17 00:00:00 2001 From: gaojun Date: Wed, 16 Feb 2022 14:18:32 +0800 Subject: [PATCH 34/38] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddatafusion=20ExecutionC?= =?UTF-8?q?ontext=E6=B2=A1=E6=9C=89=E8=87=AA=E5=8A=A8=E6=B3=A8=E5=86=8Cudf?= =?UTF-8?q?=E5=92=8Cudaf=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- datafusion/src/execution/plugin_manager.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/src/execution/plugin_manager.rs b/datafusion/src/execution/plugin_manager.rs index 286172317ae3b..038615e4ec8a3 100644 --- a/datafusion/src/execution/plugin_manager.rs +++ b/datafusion/src/execution/plugin_manager.rs @@ -42,7 +42,7 @@ pub trait PluginManager { if let Some(path) = entry.path().extension() { if let Some(suffix) = path.to_str() { - if suffix == "dylib" || suffix == "os" || suffix == "dll" { + if suffix == "dylib" || suffix == "so" || suffix == "dll" { info!("load plugin from library file:{}", path.to_str().unwrap()); println!( "load plugin from library file:{}", From d9b6d90801dcb36274b3ca7c98ed3f369551e605 Mon Sep 17 00:00:00 2001 From: gaojun Date: Tue, 22 Feb 2022 17:47:26 +0800 Subject: [PATCH 35/38] tmp --- .../core/src/serde/logical_plan/from_proto.rs | 92 +++++----- .../src/serde/physical_plan/from_proto.rs | 101 ++++++----- datafusion/Cargo.toml | 2 + datafusion/src/execution/context.rs | 52 +++--- datafusion/src/execution/mod.rs | 3 - datafusion/src/execution/plugin_manager.rs | 84 --------- .../src/execution/udaf_plugin_manager.rs | 170 ------------------ .../src/execution/udf_plugin_manager.rs | 164 ----------------- datafusion/src/lib.rs | 2 + datafusion/src/physical_plan/mod.rs | 4 - datafusion/src/physical_plan/udaf.rs | 59 ------ datafusion/src/physical_plan/udf.rs | 59 ------ datafusion/src/plugin/mod.rs | 120 +++++++++++++ datafusion/src/plugin/plugin_manager.rs | 131 ++++++++++++++ datafusion/src/plugin/udf.rs | 88 +++++++++ 15 files changed, 475 insertions(+), 656 deletions(-) delete mode 100644 datafusion/src/execution/plugin_manager.rs delete mode 100644 datafusion/src/execution/udaf_plugin_manager.rs delete mode 100644 datafusion/src/execution/udf_plugin_manager.rs create mode 100644 datafusion/src/plugin/mod.rs create mode 100644 datafusion/src/plugin/plugin_manager.rs create mode 100644 datafusion/src/plugin/udf.rs diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs 
b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 51341544efdd5..d8c6f5ee6d828 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -1080,44 +1080,56 @@ impl TryInto for &protobuf::LogicalExprNode { } // argo engine add start ExprType::AggregateUdfExpr(expr) => { - let fun = UDAF_PLUGIN_MANAGER - .aggregate_udfs.get(expr.fun_name.as_str()).ok_or_else(|| { - proto_error(format!( - "can not get udaf:{} from UDAF_PLUGIN_MANAGER.aggregate_udf_plugins!", - expr.fun_name.to_string() - )) - })?; - let fun = fun - .get_aggregate_udf_by_name(expr.fun_name.as_str()) - .map_err(|e| BallistaError::DataFusionError(e))?; - let fun_arc = Arc::new(fun); - let fun_args = &expr.args; - let args: Vec = fun_args - .iter() - .map(|e| e.try_into()) - .collect::, BallistaError>>()?; - Ok(Expr::AggregateUDF { fun: fun_arc, args }) + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager + .aggregate_udfs + .get(expr.fun_name.as_str()) + .ok_or_else(|| { + proto_error(format!( + "can not get udaf:{} from udf_plugins!", + expr.fun_name.to_string() + )) + })?; + let fun_arc = fun.clone(); + let fun_args = &expr.args; + let args: Vec = fun_args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(Expr::AggregateUDF { fun: fun_arc, args }) + } else { + Err(proto_error("can not get udf plugin".to_string())) + } } ExprType::ScalarUdfProtoExpr(expr) => { - let fun = UDF_PLUGIN_MANAGER - .scalar_udfs - .get(expr.fun_name.as_str()) - .ok_or_else(|| { - proto_error(format!( - "can not get udf:{} from UDF_PLUGIN_MANAGER.scalar_udfs!", - expr.fun_name.to_string() - )) - })?; - let fun = fun - .get_scalar_udf_by_name(expr.fun_name.as_str()) - .map_err(|e| BallistaError::DataFusionError(e))?; - let fun_arc = Arc::new(fun); - let fun_args = &expr.args; - let args: Vec = fun_args - .iter() - .map(|e| e.try_into()) - .collect::, BallistaError>>()?; - Ok(Expr::ScalarUDF { fun: fun_arc, args }) + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager + .scalar_udfs + .get(expr.fun_name.as_str()) + .ok_or_else(|| { + proto_error(format!( + "can not get udf:{} from udf_plugins!", + expr.fun_name.to_string() + )) + })?; + let fun_arc = fun.clone(); + let fun_args = &expr.args; + let args: Vec = fun_args + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + Ok(Expr::ScalarUDF { fun: fun_arc, args }) + } else { + Err(proto_error(format!("can not found udf plugins!"))) + } } // argo engine add end ExprType::Alias(alias) => Ok(Expr::Alias( Box::new(parse_required_expr(&alias.expr)?), @@ -1321,12 +1333,10 @@ impl TryInto for &protobuf::Field { } use crate::serde::protobuf::ColumnStats; -use datafusion::execution::plugin_manager::PluginManager; -use datafusion::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; -use datafusion::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; -use datafusion::physical_plan::udaf::{AggregateUDF, UDAFPlugin}; -use datafusion::physical_plan::udf::UDFPlugin; use datafusion::physical_plan::{aggregates, windows}; +use datafusion::plugin::plugin_manager::global_plugin_manager; +use 
datafusion::plugin::udf::UDFPluginManager; +use datafusion::plugin::PluginEnum; use datafusion::prelude::{ array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256, sha384, sha512, trim, upper, diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index e006631f9c948..72a0455ea396e 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -43,8 +43,6 @@ use datafusion::execution::context::{ ExecutionConfig, ExecutionContextState, ExecutionProps, }; use datafusion::execution::runtime_env::RuntimeEnv; -use datafusion::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; -use datafusion::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; use datafusion::logical_plan::{ window_frames::WindowFrame, DFSchema, Expr, JoinConstraint, JoinType, }; @@ -58,10 +56,8 @@ use datafusion::physical_plan::hash_join::PartitionMode; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion::physical_plan::planner::DefaultPhysicalPlanner; use datafusion::physical_plan::sorts::sort::{SortExec, SortOptions}; -use datafusion::physical_plan::udaf::{ - create_aggregate_expr as create_aggregate_udf_expr, UDAFPlugin, -}; -use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr, UDFPlugin}; +use datafusion::physical_plan::udaf::create_aggregate_expr as create_aggregate_udf_expr; +use datafusion::physical_plan::udf::{create_physical_expr, ScalarUDFExpr}; use datafusion::physical_plan::window_functions::{ BuiltInWindowFunction, WindowFunction, }; @@ -86,6 +82,9 @@ use datafusion::physical_plan::{ use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::plugin::plugin_manager::global_plugin_manager; +use datafusion::plugin::udf::UDFPluginManager; +use datafusion::plugin::PluginEnum; use datafusion::prelude::CsvReadOptions; use log::debug; use protobuf::physical_expr_node::ExprType; @@ -321,28 +320,33 @@ impl TryInto> for &protobuf::PhysicalPlanNode { ExprType::AggregateUdfExpr(agg_node) => { let name = agg_node.fun_name.as_str(); let udaf_fun_name = &name[0..name.find('(').unwrap()]; - let fun = UDAF_PLUGIN_MANAGER - .aggregate_udfs.get(udaf_fun_name).ok_or_else(|| { - proto_error(format!( - "can not get udaf:{} from UDAF_PLUGIN_MANAGER.aggregate_udf_plugins!", - udaf_fun_name.to_string() - )) - })?; - let fun = fun - .get_aggregate_udf_by_name(udaf_fun_name) - .map_err(|e| BallistaError::DataFusionError(e))?; - - let args: Vec> = agg_node.expr - .iter() - .map(|e| e.try_into()) - .collect::, BallistaError>>()?; - - Ok(create_aggregate_udf_expr( - &fun, - &args, - &physical_schema, - name.to_string(), - )?) + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager.aggregate_udfs.get(udaf_fun_name).ok_or_else(|| { + proto_error(format!( + "can not get udaf:{} from plugins!", + udaf_fun_name.to_string() + )) + })?; + let aggregate_udf = &*fun.clone(); + let args: Vec> = agg_node.expr + .iter() + .map(|e| e.try_into()) + .collect::, BallistaError>>()?; + + Ok(create_aggregate_udf_expr( + aggregate_udf, + &args, + &physical_schema, + name.to_string(), + )?) + } else { + Err(proto_error(format!( + "can not found udf plugin!" 
+ ))) + } } // argo engine add end. _ => Err(BallistaError::General( "Invalid aggregate expression for HashAggregateExec" @@ -580,32 +584,37 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { } // argo engine add. ExprType::ScalarUdfProtoExpr(e) => { - let fun = - UDF_PLUGIN_MANAGER + let gpm = global_plugin_manager("").lock().unwrap(); + let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + let fun = udf_plugin_manager .scalar_udfs .get(&e.fun_name) .ok_or_else(|| { proto_error(format!( - "can not get udf:{} from UDF_PLUGIN_MANAGER.scalar_udfs!", + "can not get udf:{} from plugin!", &e.fun_name.to_owned() )) })?; - let fun = fun - .get_scalar_udf_by_name(&e.fun_name.as_str()) - .map_err(|e| BallistaError::DataFusionError(e))?; - - let args = e - .expr - .iter() - .map(|x| x.try_into()) - .collect::, _>>()?; - Arc::new(ScalarUDFExpr::new( - e.fun_name.as_str(), - fun, - args, - &convert_required!(e.return_type)?, - )) + let scalar_udf = &*fun.clone(); + let args = e + .expr + .iter() + .map(|x| x.try_into()) + .collect::, _>>()?; + + Arc::new(ScalarUDFExpr::new( + e.fun_name.as_str(), + scalar_udf.clone(), + args, + &convert_required!(e.return_type)?, + )) + } else { + return Err(proto_error(format!("can not found plugin!"))); + } } ExprType::AggregateUdfExpr(_) => { return Err(BallistaError::General( diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 3728dbb4803f3..344ad948c0cc8 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -80,6 +80,8 @@ pyo3 = { version = "0.14", optional = true } tempfile = "3" libloading = "0.7.3" rustc_version = "0.4.0" +walkdir = "2.3.2" +once_cell = "1.9.0" [dev-dependencies] criterion = "0.3" diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 24c919b24bf09..6e542c7c68043 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -77,15 +77,15 @@ use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use crate::execution::udaf_plugin_manager::UDAF_PLUGIN_MANAGER; -use crate::execution::udf_plugin_manager::UDF_PLUGIN_MANAGER; use crate::logical_plan::plan::Explain; use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::physical_plan::planner::DefaultPhysicalPlanner; -use crate::physical_plan::udaf::UDAFPlugin; -use crate::physical_plan::udf::{ScalarUDF, UDFPlugin}; +use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::PhysicalPlanner; +use crate::plugin::plugin_manager::global_plugin_manager; +use crate::plugin::udf::UDFPluginManager; +use crate::plugin::PluginEnum; use crate::sql::{ parser::{DFParser, FileType}, planner::{ContextProvider, SqlToRel}, @@ -191,36 +191,33 @@ impl ExecutionContext { scalar_functions: HashMap::new(), var_provider: HashMap::new(), aggregate_functions: HashMap::new(), - config, + config: config.clone(), execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), runtime_env, })), }; + let gpm = global_plugin_manager(config.plugin_dir.as_str()); + // register udf - UDF_PLUGIN_MANAGER - .scalar_udfs - .iter() - .for_each(|(udf_name, plugin_proxy)| { - context.register_udf( - plugin_proxy - .get_scalar_udf_by_name(udf_name.as_str()) - .unwrap(), - ) - }); - - // 
register udaf - UDAF_PLUGIN_MANAGER.aggregate_udfs.iter().for_each( - |(udaf_name, plugin_proxy)| { - context.register_udaf( - plugin_proxy - .get_aggregate_udf_by_name(udaf_name.as_str()) - .unwrap(), - ); - }, - ); + let gpm_guard = gpm.lock().unwrap(); + let plugin_registrar = gpm_guard.plugin_managers.get(&PluginEnum::UDF).unwrap(); + if let Some(udf_plugin_manager) = + plugin_registrar.as_any().downcast_ref::() + { + udf_plugin_manager + .scalar_udfs + .iter() + .for_each(|(_, scalar_udf)| context.register_udf((**scalar_udf).clone())); + udf_plugin_manager + .aggregate_udfs + .iter() + .for_each(|(_, aggregate_udf)| { + context.register_udaf((**aggregate_udf).clone()) + }); + } context } @@ -930,6 +927,8 @@ pub struct ExecutionConfig { parquet_pruning: bool, /// Runtime configurations such as memory threshold and local disk for spill pub runtime_config: RuntimeConfig, + /// plugin dir + pub plugin_dir: String, } impl Default for ExecutionConfig { @@ -965,6 +964,7 @@ impl Default for ExecutionConfig { repartition_windows: true, parquet_pruning: true, runtime_config: RuntimeConfig::default(), + plugin_dir: "".to_owned(), } } } diff --git a/datafusion/src/execution/mod.rs b/datafusion/src/execution/mod.rs index 09d5d6dc570ec..ebc7c011970b3 100644 --- a/datafusion/src/execution/mod.rs +++ b/datafusion/src/execution/mod.rs @@ -22,7 +22,4 @@ pub mod dataframe_impl; pub(crate) mod disk_manager; pub(crate) mod memory_manager; pub mod options; -pub mod plugin_manager; pub mod runtime_env; -pub mod udaf_plugin_manager; -pub mod udf_plugin_manager; diff --git a/datafusion/src/execution/plugin_manager.rs b/datafusion/src/execution/plugin_manager.rs deleted file mode 100644 index 038615e4ec8a3..0000000000000 --- a/datafusion/src/execution/plugin_manager.rs +++ /dev/null @@ -1,84 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! plugin manager - -use log::info; -use std::fs::DirEntry; -use std::{env, fs, io}; - -/// plugin manager trait -pub trait PluginManager { - /// # Safety - /// find plugin file from `plugin_path` and load it . 
- unsafe fn load(&mut self, plugin_path: String) -> io::Result<()> { - // find library file from udaf_plugin_path - info!("load plugin from dir:{}", plugin_path); - println!("load plugin from dir:{}", plugin_path); - let library_files = fs::read_dir(plugin_path) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - for entry in library_files { - let entry = entry?; - let file_type = entry.file_type()?; - - if !file_type.is_file() { - continue; - } - - if let Some(path) = entry.path().extension() { - if let Some(suffix) = path.to_str() { - if suffix == "dylib" || suffix == "so" || suffix == "dll" { - info!("load plugin from library file:{}", path.to_str().unwrap()); - println!( - "load plugin from library file:{}", - path.to_str().unwrap() - ); - self.load_plugin_from_library(&entry)?; - } - } - } - } - - Ok(()) - } - - /// # Safety - /// load plugin from the library `file` . Every different plugins should have different implementations - unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()>; -} - -/// get the plugin dir -pub fn plugin_dir() -> String { - let current_exe_dir = match env::current_exe() { - Ok(exe_path) => exe_path.display().to_string(), - Err(_e) => "".to_string(), - }; - - // If current_exe_dir contain `deps` the root dir is the parent dir - // eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/deps/plugins_app-067452b3ff2af70e - // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug - // else eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/plugins_app - // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/ - if current_exe_dir.contains("/deps/") { - let i = current_exe_dir.find("/deps/").unwrap(); - String::from(¤t_exe_dir.as_str()[..i]) - } else { - let i = current_exe_dir.rfind('/').unwrap(); - String::from(¤t_exe_dir.as_str()[..i]) - } -} diff --git a/datafusion/src/execution/udaf_plugin_manager.rs b/datafusion/src/execution/udaf_plugin_manager.rs deleted file mode 100644 index 3a5054d6a20d1..0000000000000 --- a/datafusion/src/execution/udaf_plugin_manager.rs +++ /dev/null @@ -1,170 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! udaf plugin manager -//! -use crate::physical_plan::udaf::UDAFPluginRegistrar as UDAFPluginRegistrarTrait; -use crate::physical_plan::udaf::{AggregateUDF, UDAFPlugin, UDAFPluginDeclaration}; -use lazy_static::lazy_static; -use libloading::Library; -use std::collections::HashMap; -use std::fs::DirEntry; -use std::io; -use std::sync::Arc; - -use crate::error::Result; -use crate::execution::plugin_manager::{plugin_dir, PluginManager}; -use crate::physical_plan::{CORE_VERSION, RUSTC_VERSION}; -lazy_static! 
{ - /// load all udaf plugin - pub static ref UDAF_PLUGIN_MANAGER: UDAFPluginManager = unsafe { - let mut plugin = UDAFPluginManager::default(); - let plugin_path = plugin_dir(); - plugin.load(plugin_path).unwrap(); - plugin - }; -} - -/// UDAFPluginManager -#[derive(Default)] -pub struct UDAFPluginManager { - /// aggregate udfs save as udaf_name:UDAFPluginProxy - pub aggregate_udfs: HashMap>, - - /// Every Library need a plugin_name . - pub plugin_names: Vec, - - /// All libraries load from the plugin dir. - pub libraries: Vec>, -} - -impl PluginManager for UDAFPluginManager { - unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()> { - // load the library into memory - let library = Library::new(file.path()) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - let library = Arc::new(library); - - // get a pointer to the plugin_declaration symbol. - let dec = library - .get::<*mut UDAFPluginDeclaration>(b"udaf_plugin_declaration\0") - .unwrap() - .read(); - - // version checks to prevent accidental ABI incompatibilities - if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { - return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch")); - } - - let mut registrar = UDAFPluginRegistrar::new(library.clone()); - (dec.register)(&mut registrar); - - // Check for duplicate plugin_name and UDF names - if let Some(udaf_plugin_proxy) = registrar.udaf_plugin_proxy { - if self.plugin_names.contains(&udaf_plugin_proxy.plugin_name) { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "plugin name: {} already exists", - udaf_plugin_proxy.plugin_name - ), - )); - } - - udaf_plugin_proxy - .aggregate_udf_plugin - .udaf_names() - .unwrap() - .iter() - .try_for_each(|udaf_name| { - if self.aggregate_udfs.contains_key(udaf_name) { - Err(io::Error::new( - io::ErrorKind::Other, - format!( - "udaf name: {} already exists in plugin: {}", - udaf_name, udaf_plugin_proxy.plugin_name - ), - )) - } else { - self.aggregate_udfs.insert( - udaf_name.to_string(), - Arc::new(udaf_plugin_proxy.clone()), - ); - Ok(()) - } - })?; - - self.plugin_names.push(udaf_plugin_proxy.plugin_name); - self.libraries.push(library); - } - - Ok(()) - } -} - -/// A proxy object which wraps a [`UDAFPlugin`] and makes sure it can't outlive -/// the library it came from. 
-#[derive(Clone)] -pub struct UDAFPluginProxy { - /// One UDAFPluginProxy only have one UDAFPlugin - aggregate_udf_plugin: Arc>, - - /// Library - _lib: Arc, - - /// One Library can only have one plugin - plugin_name: String, -} - -impl UDAFPlugin for UDAFPluginProxy { - fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result { - self.aggregate_udf_plugin - .get_aggregate_udf_by_name(fun_name) - } - - fn udaf_names(&self) -> Result> { - self.aggregate_udf_plugin.udaf_names() - } -} - -/// impl UDAFPluginRegistrarTrait -struct UDAFPluginRegistrar { - udaf_plugin_proxy: Option, - lib: Arc, -} - -impl UDAFPluginRegistrar { - pub fn new(lib: Arc) -> Self { - Self { - udaf_plugin_proxy: None, - lib, - } - } -} - -impl UDAFPluginRegistrarTrait for UDAFPluginRegistrar { - fn register_udaf_plugin(&mut self, plugin_name: &str, plugin: Box) { - let proxy = UDAFPluginProxy { - aggregate_udf_plugin: Arc::new(plugin), - _lib: self.lib.clone(), - plugin_name: plugin_name.to_string(), - }; - - self.udaf_plugin_proxy = Some(proxy); - } -} diff --git a/datafusion/src/execution/udf_plugin_manager.rs b/datafusion/src/execution/udf_plugin_manager.rs deleted file mode 100644 index 6c7c08dcbaa6d..0000000000000 --- a/datafusion/src/execution/udf_plugin_manager.rs +++ /dev/null @@ -1,164 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! udf plugin manager -//! -use crate::execution::plugin_manager::{plugin_dir, PluginManager}; -use crate::physical_plan::udf::UDFPluginRegistrar as UDFPluginRegistrarTrait; -use crate::physical_plan::udf::{ScalarUDF, UDFPlugin, UDFPluginDeclaration}; -use lazy_static::lazy_static; -use libloading::Library; -use std::collections::HashMap; -use std::fs::DirEntry; -use std::io; -use std::sync::Arc; - -use crate::error::Result; -use crate::physical_plan::{CORE_VERSION, RUSTC_VERSION}; -lazy_static! { - /// load all udf plugin - pub static ref UDF_PLUGIN_MANAGER: UDFPluginManager = unsafe { - let mut plugin = UDFPluginManager::default(); - let plugin_path = plugin_dir(); - plugin.load(plugin_path).unwrap(); - plugin - }; -} - -/// UDFPluginManager -#[derive(Default)] -pub struct UDFPluginManager { - /// scalar udfs save as udaf_name:UDFPluginProxy - pub scalar_udfs: HashMap>, - - /// Every Library need a plugin_name . - pub plugin_names: Vec, - - /// All libraries load from the plugin dir. - pub libraries: Vec>, -} - -impl PluginManager for UDFPluginManager { - unsafe fn load_plugin_from_library(&mut self, file: &DirEntry) -> io::Result<()> { - // load the library into memory - let library = Library::new(file.path()) - .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; - - let library = Arc::new(library); - - // get a pointer to the plugin_declaration symbol. 
- let dec = library - .get::<*mut UDFPluginDeclaration>(b"udf_plugin_declaration\0") - .unwrap() - .read(); - - // version checks to prevent accidental ABI incompatibilities - if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { - return Err(io::Error::new(io::ErrorKind::Other, "Version mismatch")); - } - - let mut registrar = UDFPluginRegistrar::new(library.clone()); - (dec.register)(&mut registrar); - - // Check for duplicate plugin_name and UDF names - if let Some(udf_plugin_proxy) = registrar.udf_plugin_proxy { - if self.plugin_names.contains(&udf_plugin_proxy.plugin_name) { - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "plugin name: {} already exists", - udf_plugin_proxy.plugin_name - ), - )); - } - - udf_plugin_proxy - .scalar_udf_plugin - .udf_names() - .unwrap() - .iter() - .try_for_each(|udf_name| { - if self.scalar_udfs.contains_key(udf_name) { - Err(io::Error::new( - io::ErrorKind::Other, - format!( - "udf name: {} already exists in plugin: {}", - udf_name, udf_plugin_proxy.plugin_name - ), - )) - } else { - self.scalar_udfs.insert( - udf_name.to_string(), - Arc::new(udf_plugin_proxy.clone()), - ); - Ok(()) - } - })?; - - self.plugin_names.push(udf_plugin_proxy.plugin_name); - self.libraries.push(library); - } - - Ok(()) - } -} - -/// A proxy object which wraps a [`UDFPlugin`] and makes sure it can't outlive -/// the library it came from. - -#[derive(Clone)] -pub struct UDFPluginProxy { - scalar_udf_plugin: Arc>, - _lib: Arc, - plugin_name: String, -} - -impl UDFPlugin for UDFPluginProxy { - fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result { - self.scalar_udf_plugin.get_scalar_udf_by_name(fun_name) - } - - fn udf_names(&self) -> Result> { - self.scalar_udf_plugin.udf_names() - } -} - -struct UDFPluginRegistrar { - udf_plugin_proxy: Option, - lib: Arc, -} - -impl UDFPluginRegistrar { - pub fn new(lib: Arc) -> Self { - Self { - udf_plugin_proxy: None, - lib, - } - } -} - -impl UDFPluginRegistrarTrait for UDFPluginRegistrar { - fn register_udf_plugin(&mut self, plugin_name: &str, plugin: Box) { - let proxy = UDFPluginProxy { - scalar_udf_plugin: Arc::new(plugin), - _lib: self.lib.clone(), - plugin_name: plugin_name.to_string(), - }; - - self.udf_plugin_proxy = Some(proxy); - } -} diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index fd574d7d76aee..1d3c5250850bf 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -213,6 +213,8 @@ pub mod logical_plan; pub mod optimizer; pub mod physical_optimizer; pub mod physical_plan; +/// plugin mod +pub mod plugin; pub mod prelude; pub mod scalar; pub mod sql; diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 42f0cf22d0c14..836d3994343f2 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -629,10 +629,6 @@ pub trait Accumulator: Send + Sync + Debug { fn evaluate(&self) -> Result; } -/// CARGO_PKG_VERSION -pub static CORE_VERSION: &str = env!("CARGO_PKG_VERSION"); -/// RUSTC_VERSION -pub static RUSTC_VERSION: &str = env!("RUSTC_VERSION"); pub mod aggregates; pub mod analyze; pub mod array_expressions; diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 66f8b2c96b151..974b4a9df764f 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -39,65 +39,6 @@ use super::{ }; use std::sync::Arc; -/// 定义udaf插件,udaf的定义方需要实现该trait -pub trait UDAFPlugin: Send + Sync + 'static { - /// get a aggregate udf by name - fn 
get_aggregate_udf_by_name(&self, fun_name: &str) -> Result; - - /// return all udaf names - fn udaf_names(&self) -> Result>; -} - -/// Every plugin need a UDAFPluginDeclaration -#[derive(Copy, Clone)] -pub struct UDAFPluginDeclaration { - /// rustc version of the plugin. The plugin's rustc_version need same as plugin manager. - pub rustc_version: &'static str, - - /// core version of the plugin. The plugin's core_version need same as plugin manager. - pub core_version: &'static str, - - /// `register` is a function which impl UDAFPluginRegistrar. It will be call when plugin load. - pub register: unsafe extern "C" fn(&mut dyn UDAFPluginRegistrar), -} - -/// UDAF Plugin Registrar , Define the functions every udaf plugin need impl -pub trait UDAFPluginRegistrar { - /// The udaf plugin need impl this function - fn register_udaf_plugin(&mut self, plugin_name: &str, function: Box); -} - -/// Declare a aggregate udf plugin's name, type and its constructor. -/// -/// # Notes -/// -/// This works by automatically generating an `extern "C"` function with a -/// pre-defined signature and symbol name. And then generating a UDAFPluginDeclaration. -/// Therefore you will only be able to declare one plugin per library. -#[macro_export] -macro_rules! declare_udaf_plugin { - ($plugin_name:expr, $plugin_type:ty, $constructor:path) => { - #[no_mangle] - pub extern "C" fn register_udaf_plugin( - registrar: &mut dyn $crate::physical_plan::udaf::UDAFPluginRegistrar, - ) { - // make sure the constructor is the correct type. - let constructor: fn() -> $plugin_type = $constructor; - let object = constructor(); - registrar.register_udaf_plugin($plugin_name, Box::new(object)); - } - - #[no_mangle] - pub static udaf_plugin_declaration: - $crate::physical_plan::udaf::UDAFPluginDeclaration = - $crate::physical_plan::udaf::UDAFPluginDeclaration { - rustc_version: $crate::physical_plan::RUSTC_VERSION, - core_version: $crate::physical_plan::CORE_VERSION, - register: register_udaf_plugin, - }; - }; -} - /// Logical representation of a user-defined aggregate function (UDAF) /// A UDAF is different from a UDF in that it is stateful across batches. #[derive(Clone)] diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 69580a6cd83b3..55c37dea3374e 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -34,65 +34,6 @@ use crate::physical_plan::ColumnarValue; use arrow::record_batch::RecordBatch; use std::sync::Arc; -/// 定义udf插件,udf的定义方需要实现该trait -pub trait UDFPlugin: Send + Sync + 'static { - /// get a ScalarUDF by name - fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result; - - /// return all udf names in the plugin - fn udf_names(&self) -> Result>; -} - -/// Every plugin need a UDFPluginDeclaration -#[derive(Copy, Clone)] -pub struct UDFPluginDeclaration { - /// rustc version of the plugin. The plugin's rustc_version need same as plugin manager. - pub rustc_version: &'static str, - - /// core version of the plugin. The plugin's core_version need same as plugin manager. - pub core_version: &'static str, - - /// `register` is a function which impl UDFPluginRegistrar. It will be call when plugin load. 
- pub register: unsafe extern "C" fn(&mut dyn UDFPluginRegistrar), -} - -/// UDF Plugin Registrar , Define the functions every udf plugin need impl -pub trait UDFPluginRegistrar { - /// The udf plugin need impl this function - fn register_udf_plugin(&mut self, plugin_name: &str, function: Box); -} - -/// Declare a plugin's name, type and its constructor. -/// -/// # Notes -/// -/// This works by automatically generating an `extern "C"` function with a -/// pre-defined signature and symbol name. And then generating a UDFPluginDeclaration. -/// Therefore you will only be able to declare one plugin per library. -#[macro_export] -macro_rules! declare_udf_plugin { - ($plugin_name:expr, $plugin_type:ty, $constructor:path) => { - #[no_mangle] - pub extern "C" fn register_udf_plugin( - registrar: &mut dyn $crate::physical_plan::udf::UDFPluginRegistrar, - ) { - // make sure the constructor is the correct type. - let constructor: fn() -> $plugin_type = $constructor; - let object = constructor(); - registrar.register_udf_plugin($plugin_name, Box::new(object)); - } - - #[no_mangle] - pub static udf_plugin_declaration: - $crate::physical_plan::udf::UDFPluginDeclaration = - $crate::physical_plan::udf::UDFPluginDeclaration { - rustc_version: $crate::physical_plan::RUSTC_VERSION, - core_version: $crate::physical_plan::CORE_VERSION, - register: register_udf_plugin, - }; - }; -} - /// Logical representation of a UDF. #[derive(Clone)] pub struct ScalarUDF { diff --git a/datafusion/src/plugin/mod.rs b/datafusion/src/plugin/mod.rs new file mode 100644 index 0000000000000..f6f554433ec95 --- /dev/null +++ b/datafusion/src/plugin/mod.rs @@ -0,0 +1,120 @@ +use crate::error::Result; +use crate::plugin::udf::UDFPluginManager; +use std::any::Any; +use std::env; + +/// plugin manager +pub mod plugin_manager; +/// udf plugin +pub mod udf; + +/// CARGO_PKG_VERSION +pub static CORE_VERSION: &str = env!("CARGO_PKG_VERSION"); +/// RUSTC_VERSION +pub static RUSTC_VERSION: &str = env!("RUSTC_VERSION"); + +/// Top plugin trait +pub trait Plugin { + /// Returns the plugin as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} + +/// The enum of Plugin +#[derive(PartialEq, std::cmp::Eq, std::hash::Hash, Copy, Clone)] +pub enum PluginEnum { + /// UDF/UDAF plugin + UDF, +} + +impl PluginEnum { + /// new a struct which impl the PluginRegistrar trait + pub fn init_plugin_manager(&self) -> Box { + match self { + PluginEnum::UDF => Box::new(UDFPluginManager::default()), + } + } +} + +/// Every plugin need a PluginDeclaration +#[derive(Copy, Clone)] +pub struct PluginDeclaration { + /// rustc version of the plugin. The plugin's rustc_version need same as plugin manager. + pub rustc_version: &'static str, + + /// core version of the plugin. The plugin's core_version need same as plugin manager. + pub core_version: &'static str, + + /// One of PluginEnum + pub plugin_type: unsafe extern "C" fn() -> PluginEnum, + + /// `register` is a function which impl PluginRegistrar. It will be call when plugin load. + pub register: unsafe extern "C" fn(&mut Box), +} + +/// Plugin Registrar , Every plugin need implement this trait +pub trait PluginRegistrar: Send + Sync + 'static { + /// The implementer of the plug-in needs to call this interface to report his own information to the plug-in manager + fn register_plugin(&mut self, plugin: Box) -> Result<()>; + + /// Returns the plugin registrar as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. 
+ fn as_any(&self) -> &dyn Any; +} + +/// Declare a plugin's PluginDeclaration. +/// +/// # Notes +/// +/// This works by automatically generating an `extern "C"` function with a +/// pre-defined signature and symbol name. And then generating a PluginDeclaration. +/// Therefore you will only be able to declare one plugin per library. +#[macro_export] +macro_rules! declare_plugin { + ($plugin_type:expr, $curr_plugin_type:ty, $constructor:path) => { + #[no_mangle] + pub extern "C" fn register_plugin( + registrar: &mut Box, + ) { + // make sure the constructor is the correct type. + let constructor: fn() -> $curr_plugin_type = $constructor; + let object = constructor(); + registrar.register_plugin(Box::new(object)); + } + + #[no_mangle] + pub extern "C" fn get_plugin_type() -> $crate::plugin::PluginEnum { + $plugin_type + } + + #[no_mangle] + pub static plugin_declaration: $crate::plugin::PluginDeclaration = + $crate::plugin::PluginDeclaration { + rustc_version: $crate::plugin::RUSTC_VERSION, + core_version: $crate::plugin::CORE_VERSION, + plugin_type: get_plugin_type, + register: register_plugin, + }; + }; +} + +/// get the plugin dir +pub fn plugin_dir() -> String { + let current_exe_dir = match env::current_exe() { + Ok(exe_path) => exe_path.display().to_string(), + Err(_e) => "".to_string(), + }; + + // If current_exe_dir contain `deps` the root dir is the parent dir + // eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/deps/plugins_app-067452b3ff2af70e + // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug + // else eg: /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/plugins_app + // the plugin dir is /Users/xxx/workspace/rust/rust_plugin_sty/target/debug/ + if current_exe_dir.contains("/deps/") { + let i = current_exe_dir.find("/deps/").unwrap(); + String::from(¤t_exe_dir.as_str()[..i]) + } else { + let i = current_exe_dir.rfind('/').unwrap(); + String::from(¤t_exe_dir.as_str()[..i]) + } +} diff --git a/datafusion/src/plugin/plugin_manager.rs b/datafusion/src/plugin/plugin_manager.rs new file mode 100644 index 0000000000000..a8a19e4ac8d9a --- /dev/null +++ b/datafusion/src/plugin/plugin_manager.rs @@ -0,0 +1,131 @@ +use crate::error::{DataFusionError, Result}; +use crate::plugin::{PluginDeclaration, CORE_VERSION, RUSTC_VERSION}; +use crate::plugin::{PluginEnum, PluginRegistrar}; +use libloading::Library; +use log::info; +use std::collections::HashMap; +use std::io; +use std::sync::{Arc, Mutex}; +use walkdir::{DirEntry, WalkDir}; + +use once_cell::sync::OnceCell; + +/// To prevent the library from being loaded multiple times, we use once_cell defines a Arc> +/// Because datafusion is a library, not a service, users may not need to load all plug-ins in the process. +/// So fn global_plugin_manager return Arc>. In this way, users can load the required library through the load method of GlobalPluginManager when needed +pub fn global_plugin_manager( + plugin_path: &str, +) -> &'static Arc> { + static INSTANCE: OnceCell>> = OnceCell::new(); + INSTANCE.get_or_init(move || unsafe { + let mut gpm = GlobalPluginManager::default(); + gpm.load(plugin_path).unwrap(); + Arc::new(Mutex::new(gpm)) + }) +} + +#[derive(Default)] +/// manager all plugin_type's plugin_manager +pub struct GlobalPluginManager { + /// every plugin need a plugin registrar + pub plugin_managers: HashMap>, + + /// loaded plugin files + pub plugin_files: Vec, +} + +impl GlobalPluginManager { + /// # Safety + /// find plugin file from `plugin_path` and load it . 
+ unsafe fn load(&mut self, plugin_path: &str) -> Result<()> { + // find library file from udaf_plugin_path + info!("load plugin from dir:{}", plugin_path); + println!("load plugin from dir:{}", plugin_path); + + let plugin_files = self.get_all_plugin_files(plugin_path)?; + + for plugin_file in plugin_files { + let library = Library::new(plugin_file.path()).map_err(|e| { + DataFusionError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("load library error: {}", e), + )) + })?; + + let library = Arc::new(library); + + // get a pointer to the plugin_declaration symbol. + let dec = library + .get::<*mut PluginDeclaration>(b"plugin_declaration\0") + .map_err(|e| { + DataFusionError::IoError(io::Error::new( + io::ErrorKind::Other, + format!("not found plugin_declaration in the library: {}", e), + )) + })? + .read(); + + // version checks to prevent accidental ABI incompatibilities + if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION { + return Err(DataFusionError::IoError(io::Error::new( + io::ErrorKind::Other, + "Version mismatch", + ))); + } + + let plugin_enum = (dec.plugin_type)(); + let curr_plugin_manager = match self.plugin_managers.get_mut(&plugin_enum) { + None => { + let plugin_manager = plugin_enum.init_plugin_manager(); + self.plugin_managers.insert(plugin_enum, plugin_manager); + self.plugin_managers.get_mut(&plugin_enum).unwrap() + } + Some(manager) => manager, + }; + + (dec.register)(curr_plugin_manager); + self.plugin_files + .push(plugin_file.path().to_str().unwrap().to_string()); + } + + Ok(()) + } + + /// get all plugin file in the dir + fn get_all_plugin_files(&self, plugin_path: &str) -> io::Result> { + let mut plugin_files = Vec::new(); + for entry in WalkDir::new(plugin_path).into_iter().filter_map(|e| { + let item = e.unwrap(); + // every file only load once + if self + .plugin_files + .contains(&item.path().to_str().unwrap().to_string()) + { + return None; + } + + let file_type = item.file_type(); + if !file_type.is_file() { + return None; + } + + if let Some(path) = item.path().extension() { + if let Some(suffix) = path.to_str() { + if suffix == "dylib" || suffix == "so" || suffix == "dll" { + info!("load plugin from library file:{}", path.to_str().unwrap()); + println!( + "load plugin from library file:{}", + path.to_str().unwrap() + ); + return Some(item); + } + } + } + + return None; + }) { + plugin_files.push(entry); + } + Ok(plugin_files) + } +} diff --git a/datafusion/src/plugin/udf.rs b/datafusion/src/plugin/udf.rs new file mode 100644 index 0000000000000..ffbb928fbd0f3 --- /dev/null +++ b/datafusion/src/plugin/udf.rs @@ -0,0 +1,88 @@ +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::udaf::AggregateUDF; +use crate::physical_plan::udf::ScalarUDF; +use crate::plugin::{Plugin, PluginRegistrar}; +use libloading::Library; +use std::any::Any; +use std::collections::HashMap; +use std::io; +use std::sync::Arc; + +/// 定义udf插件,udf的定义方需要实现该trait +pub trait UDFPlugin: Plugin { + /// get a ScalarUDF by name + fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result; + + /// return all udf names in the plugin + fn udf_names(&self) -> Result>; + + /// get a aggregate udf by name + fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result; + + /// return all udaf names + fn udaf_names(&self) -> Result>; +} + +/// UDFPluginManager +#[derive(Default)] +pub struct UDFPluginManager { + /// scalar udfs + pub scalar_udfs: HashMap>, + + /// aggregate udfs + pub aggregate_udfs: HashMap>, + + /// All libraries load from the 
plugin dir.
+    pub libraries: Vec<Arc<Library>>,
+}
+
+impl PluginRegistrar for UDFPluginManager {
+    fn register_plugin(&mut self, plugin: Box<dyn Plugin>) -> Result<()> {
+        if let Some(udf_plugin) = plugin.as_any().downcast_ref::<Box<dyn UDFPlugin>>() {
+            udf_plugin
+                .udf_names()
+                .unwrap()
+                .iter()
+                .try_for_each(|udf_name| {
+                    if self.scalar_udfs.contains_key(udf_name) {
+                        Err(DataFusionError::IoError(io::Error::new(
+                            io::ErrorKind::Other,
+                            format!("udf name: {} already exists", udf_name),
+                        )))
+                    } else {
+                        let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?;
+                        self.scalar_udfs
+                            .insert(udf_name.to_string(), Arc::new(scalar_udf));
+                        Ok(())
+                    }
+                })?;
+
+            udf_plugin
+                .udaf_names()
+                .unwrap()
+                .iter()
+                .try_for_each(|udaf_name| {
+                    if self.aggregate_udfs.contains_key(udaf_name) {
+                        Err(DataFusionError::IoError(io::Error::new(
+                            io::ErrorKind::Other,
+                            format!("udaf name: {} already exists", udaf_name),
+                        )))
+                    } else {
+                        let aggregate_udf =
+                            udf_plugin.get_aggregate_udf_by_name(udaf_name)?;
+                        self.aggregate_udfs
+                            .insert(udaf_name.to_string(), Arc::new(aggregate_udf));
+                        Ok(())
+                    }
+                })?;
+        }
+        Err(DataFusionError::IoError(io::Error::new(
+            io::ErrorKind::Other,
+            format!("expected plugin type is 'dyn UDFPlugin', but it's not"),
+        )))
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+}

From 8be0ce1e4725f20a47d8486de61a6ceb99cc5957 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Wed, 23 Feb 2022 12:00:11 +0800
Subject: [PATCH 36/38] update submodules

---
 .gitmodules | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 3670814a9bd9c..5d0594c0c75fa 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
 [submodule "parquet-testing"]
 	path = parquet-testing
-	url = http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/parquet-testing.git
+	url = https://github.com/apache/parquet-testing.git
 [submodule "testing"]
 	path = testing
-	url = http://argo_engine:hQTHSm845HEDA8Cs_9dk@git.analysysdata.com/noah/arrow-testing.git
+	url = https://github.com/apache/arrow-testing
\ No newline at end of file
From 5b3ca32321cebb16cb4858338640722b55097de6 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Wed, 23 Feb 2022 15:02:38 +0800
Subject: [PATCH 37/38] Fix a bug in the plugin macro
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 datafusion/src/plugin/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/src/plugin/mod.rs b/datafusion/src/plugin/mod.rs
index f6f554433ec95..67d6655a2b07c 100644
--- a/datafusion/src/plugin/mod.rs
+++ b/datafusion/src/plugin/mod.rs
@@ -79,7 +79,7 @@ macro_rules! declare_plugin {
             // make sure the constructor is the correct type.
             let constructor: fn() -> $curr_plugin_type = $constructor;
             let object = constructor();
-            registrar.register_plugin(Box::new(object));
+            registrar.register_plugin(Box::new(object)).unwrap();
         }
 
         #[no_mangle]

From c71de9ceef562b5b8b3e3022aec689a0519d5d43 Mon Sep 17 00:00:00 2001
From: gaojun
Date: Thu, 24 Feb 2022 16:18:40 +0800
Subject: [PATCH 38/38] UDF pluginization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/src/serde/logical_plan/from_proto.rs |  14 +-
 .../src/serde/physical_plan/from_proto.rs     |  13 +-
 datafusion/src/execution/context.rs           |  10 +-
 datafusion/src/plugin/mod.rs                  |  44 ++---
 datafusion/src/plugin/plugin_manager.rs       |  60 ++++---
 datafusion/src/plugin/udf.rs                  | 153 ++++++++++++------
 6 files changed, 183 insertions(+), 111 deletions(-)

diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index d8c6f5ee6d828..389c1bdb722e6 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -1080,11 +1080,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
             }
             // argo engine add start
             ExprType::AggregateUdfExpr(expr) => {
-                let gpm = global_plugin_manager("").lock().unwrap();
-                let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap();
-                if let Some(udf_plugin_manager) =
-                    plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
-                {
+                if let Some(udf_plugin_manager) = get_udf_plugin_manager("") {
                     let fun = udf_plugin_manager
                         .aggregate_udfs
                         .get(expr.fun_name.as_str())
@@ -1106,11 +1102,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                 }
             }
             ExprType::ScalarUdfProtoExpr(expr) => {
-                let gpm = global_plugin_manager("").lock().unwrap();
-                let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap();
-                if let Some(udf_plugin_manager) =
-                    plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
-                {
+                if let Some(udf_plugin_manager) = get_udf_plugin_manager("") {
                     let fun = udf_plugin_manager
                         .scalar_udfs
                         .get(expr.fun_name.as_str())
@@ -1335,7 +1327,7 @@ impl TryInto<Field> for &protobuf::Field {
 use crate::serde::protobuf::ColumnStats;
 use datafusion::physical_plan::{aggregates, windows};
 use datafusion::plugin::plugin_manager::global_plugin_manager;
-use datafusion::plugin::udf::UDFPluginManager;
+use datafusion::plugin::udf::{get_udf_plugin_manager, UDFPluginManager};
 use datafusion::plugin::PluginEnum;
 use datafusion::prelude::{
     array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256,
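The hunks above replace the manual lock-and-downcast sequence with the new `get_udf_plugin_manager` helper. A minimal sketch of the resulting lookup path during deserialization (the `geo_mean` name and the `col("v")` argument are hypothetical; the empty path assumes the global manager was already initialized elsewhere):

    use datafusion::logical_plan::{col, Expr};
    use datafusion::plugin::udf::get_udf_plugin_manager;

    fn resolve_udaf() -> Option<Expr> {
        // Returns None if the UDF plugin manager was never initialized.
        let manager = get_udf_plugin_manager("")?;
        // aggregate_udfs maps a udaf name to an Arc<AggregateUDF>.
        let fun = manager.aggregate_udfs.get("geo_mean")?.clone();
        Some(Expr::AggregateUDF { fun, args: vec![col("v")] })
    }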
global_plugin_manager("").lock().unwrap(); - let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); - if let Some(udf_plugin_manager) = plugin_registrar.as_any().downcast_ref::() - { + if let Some(udf_plugin_manager) = get_udf_plugin_manager("") { let fun = udf_plugin_manager.aggregate_udfs.get(udaf_fun_name).ok_or_else(|| { proto_error(format!( "can not get udaf:{} from plugins!", @@ -584,11 +581,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { } // argo engine add. ExprType::ScalarUdfProtoExpr(e) => { - let gpm = global_plugin_manager("").lock().unwrap(); - let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap(); - if let Some(udf_plugin_manager) = - plugin_registrar.as_any().downcast_ref::() - { + if let Some(udf_plugin_manager) = get_udf_plugin_manager("") { let fun = udf_plugin_manager .scalar_udfs .get(&e.fun_name) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6e542c7c68043..5e4e3ef6b711c 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -83,9 +83,7 @@ use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::PhysicalPlanner; -use crate::plugin::plugin_manager::global_plugin_manager; -use crate::plugin::udf::UDFPluginManager; -use crate::plugin::PluginEnum; +use crate::plugin::udf::get_udf_plugin_manager; use crate::sql::{ parser::{DFParser, FileType}, planner::{ContextProvider, SqlToRel}, @@ -198,13 +196,9 @@ impl ExecutionContext { })), }; - let gpm = global_plugin_manager(config.plugin_dir.as_str()); - // register udf - let gpm_guard = gpm.lock().unwrap(); - let plugin_registrar = gpm_guard.plugin_managers.get(&PluginEnum::UDF).unwrap(); if let Some(udf_plugin_manager) = - plugin_registrar.as_any().downcast_ref::() + get_udf_plugin_manager(config.plugin_dir.as_str()) { udf_plugin_manager .scalar_udfs diff --git a/datafusion/src/plugin/mod.rs b/datafusion/src/plugin/mod.rs index 67d6655a2b07c..1450749a2afc4 100644 --- a/datafusion/src/plugin/mod.rs +++ b/datafusion/src/plugin/mod.rs @@ -1,7 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use crate::error::Result; use crate::plugin::udf::UDFPluginManager; +use libloading::Library; use std::any::Any; use std::env; +use std::sync::Arc; /// plugin manager pub mod plugin_manager; @@ -47,17 +66,15 @@ pub struct PluginDeclaration { /// One of PluginEnum pub plugin_type: unsafe extern "C" fn() -> PluginEnum, - - /// `register` is a function which impl PluginRegistrar. It will be call when plugin load. 
diff --git a/datafusion/src/plugin/mod.rs b/datafusion/src/plugin/mod.rs
index 67d6655a2b07c..1450749a2afc4 100644
--- a/datafusion/src/plugin/mod.rs
+++ b/datafusion/src/plugin/mod.rs
@@ -1,7 +1,26 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
 use crate::error::Result;
 use crate::plugin::udf::UDFPluginManager;
+use libloading::Library;
 use std::any::Any;
 use std::env;
+use std::sync::Arc;
 
 /// plugin manager
 pub mod plugin_manager;
 /// udf plugin
 pub mod udf;
@@ -47,17 +66,15 @@ pub struct PluginDeclaration {
 
     /// One of PluginEnum
     pub plugin_type: unsafe extern "C" fn() -> PluginEnum,
-
-    /// `register` is a function which impl PluginRegistrar. It will be call when plugin load.
-    pub register: unsafe extern "C" fn(&mut Box<dyn PluginRegistrar>),
 }
 
 /// Plugin Registrar , Every plugin need implement this trait
 pub trait PluginRegistrar: Send + Sync + 'static {
-    /// The implementer of the plug-in needs to call this interface to report his own information to the plug-in manager
-    fn register_plugin(&mut self, plugin: Box<dyn Plugin>) -> Result<()>;
+    /// # Safety
+    /// Load the plugin from the given library
+    unsafe fn load(&mut self, library: Arc<Library>) -> Result<()>;
 
-    /// Returns the plugin registrar as [`Any`](std::any::Any) so that it can be
+    /// Returns the plugin as [`Any`](std::any::Any) so that it can be
     /// downcast to a specific implementation.
     fn as_any(&self) -> &dyn Any;
 }
@@ -66,22 +83,12 @@ pub trait PluginRegistrar: Send + Sync + 'static {
 ///
 /// # Notes
 ///
-/// This works by automatically generating an `extern "C"` function with a
+/// This works by automatically generating an `extern "C"` function named `get_plugin_type` with a
 /// pre-defined signature and symbol name. And then generating a PluginDeclaration.
 /// Therefore you will only be able to declare one plugin per library.
 #[macro_export]
 macro_rules! declare_plugin {
-    ($plugin_type:expr, $curr_plugin_type:ty, $constructor:path) => {
-        #[no_mangle]
-        pub extern "C" fn register_plugin(
-            registrar: &mut Box<dyn $crate::plugin::PluginRegistrar>,
-        ) {
-            // make sure the constructor is the correct type.
-            let constructor: fn() -> $curr_plugin_type = $constructor;
-            let object = constructor();
-            registrar.register_plugin(Box::new(object)).unwrap();
-        }
-
+    ($plugin_type:expr) => {
         #[no_mangle]
         pub extern "C" fn get_plugin_type() -> $crate::plugin::PluginEnum {
             $plugin_type
@@ -93,7 +100,6 @@ macro_rules! declare_plugin {
                 rustc_version: $crate::plugin::RUSTC_VERSION,
                 core_version: $crate::plugin::CORE_VERSION,
                 plugin_type: get_plugin_type,
-                register: register_plugin,
             };
     };
 }
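After this change the `register` callback is gone from the declaration: a plugin library only reports its rustc/core versions and its plugin type, and the type-specific registrar pulls its own constructor symbol out of the library. Roughly what `declare_plugin!($crate::plugin::PluginEnum::UDF)` expands to inside a plugin crate, with `$crate` written out as `datafusion` (a sketch of the expansion, not authoritative macro output):

    #[no_mangle]
    pub extern "C" fn get_plugin_type() -> datafusion::plugin::PluginEnum {
        datafusion::plugin::PluginEnum::UDF
    }

    // Mirrors the macro body; the lowercase static name is what the loader looks up.
    #[no_mangle]
    pub static plugin_declaration: datafusion::plugin::PluginDeclaration =
        datafusion::plugin::PluginDeclaration {
            rustc_version: datafusion::plugin::RUSTC_VERSION,
            core_version: datafusion::plugin::CORE_VERSION,
            plugin_type: get_plugin_type,
        };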
diff --git a/datafusion/src/plugin/plugin_manager.rs b/datafusion/src/plugin/plugin_manager.rs
index a8a19e4ac8d9a..3ffd15175d348 100644
--- a/datafusion/src/plugin/plugin_manager.rs
+++ b/datafusion/src/plugin/plugin_manager.rs
@@ -1,3 +1,19 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
 use crate::error::{DataFusionError, Result};
 use crate::plugin::{PluginDeclaration, CORE_VERSION, RUSTC_VERSION};
 use crate::plugin::{PluginEnum, PluginRegistrar};
 use libloading::Library;
 use log::info;
 use std::collections::HashMap;
 use std::io;
 use std::sync::{Arc, Mutex};
 use walkdir::{DirEntry, WalkDir};
 
 use once_cell::sync::OnceCell;
 
 /// To prevent the library from being loaded multiple times, we use once_cell defines a Arc<Mutex<GlobalPluginManager>>
 /// Because datafusion is a library, not a service, users may not need to load all plug-ins in the process.
 /// So fn global_plugin_manager return Arc<Mutex<GlobalPluginManager>>. In this way, users can load the required library through the load method of GlobalPluginManager when needed
+static INSTANCE: OnceCell<Arc<Mutex<GlobalPluginManager>>> = OnceCell::new();
+
+/// global_plugin_manager
 pub fn global_plugin_manager(
     plugin_path: &str,
 ) -> &'static Arc<Mutex<GlobalPluginManager>> {
-    static INSTANCE: OnceCell<Arc<Mutex<GlobalPluginManager>>> = OnceCell::new();
     INSTANCE.get_or_init(move || unsafe {
+        println!("====================init===================");
         let mut gpm = GlobalPluginManager::default();
         gpm.load(plugin_path).unwrap();
         Arc::new(Mutex::new(gpm))
     })
@@ -38,6 +57,9 @@ impl GlobalPluginManager {
     /// # Safety
     /// find plugin file from `plugin_path` and load it .
     unsafe fn load(&mut self, plugin_path: &str) -> Result<()> {
+        if "".eq(plugin_path) {
+            return Ok(());
+        }
         // find library file from udaf_plugin_path
         info!("load plugin from dir:{}", plugin_path);
         println!("load plugin from dir:{}", plugin_path);
@@ -54,18 +76,18 @@ impl GlobalPluginManager {
 
             let library = Arc::new(library);
 
-            // get a pointer to the plugin_declaration symbol.
-            let dec = library
-                .get::<*mut PluginDeclaration>(b"plugin_declaration\0")
-                .map_err(|e| {
-                    DataFusionError::IoError(io::Error::new(
-                        io::ErrorKind::Other,
-                        format!("not found plugin_declaration in the library: {}", e),
-                    ))
-                })?
-                .read();
-
-            // version checks to prevent accidental ABI incompatibilities
+            let dec = library.get::<*mut PluginDeclaration>(b"plugin_declaration\0");
+            if dec.is_err() {
+                info!(
+                    "not found plugin_declaration in the library: {}",
+                    plugin_file.path().to_str().unwrap()
+                );
+                return Ok(());
+            }
+
+            let dec = dec.unwrap().read();
+
+            // version checks to prevent accidental ABI incompatibilities
             if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION {
                 return Err(DataFusionError::IoError(io::Error::new(
                     io::ErrorKind::Other,
                     "Version mismatch",
                 )));
             }
@@ -82,8 +104,7 @@ impl GlobalPluginManager {
                 }
                 Some(manager) => manager,
             };
-
-            (dec.register)(curr_plugin_manager);
+            curr_plugin_manager.load(library)?;
             self.plugin_files
                 .push(plugin_file.path().to_str().unwrap().to_string());
         }
@@ -112,17 +133,20 @@ impl GlobalPluginManager {
             if let Some(path) = item.path().extension() {
                 if let Some(suffix) = path.to_str() {
                     if suffix == "dylib" || suffix == "so" || suffix == "dll" {
-                        info!("load plugin from library file:{}", path.to_str().unwrap());
+                        info!(
+                            "load plugin from library file:{}",
+                            item.path().to_str().unwrap()
+                        );
                         println!(
                             "load plugin from library file:{}",
-                            path.to_str().unwrap()
+                            item.path().to_str().unwrap()
                        );
                         return Some(item);
                     }
                 }
             }
 
-            return None;
+            None
         }) {
             plugin_files.push(entry);
         }
         Ok(plugin_files)
     }
 }
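Because of the `OnceCell`, only the first call to `global_plugin_manager` performs a load; every later call returns the same cached instance and its path argument is ignored (an empty path is now an explicit no-op). A small sketch of that property, with hypothetical paths:

    use datafusion::plugin::plugin_manager::global_plugin_manager;
    use std::sync::Arc;

    fn main() {
        let first = global_plugin_manager("/opt/datafusion/plugins"); // loads once
        let second = global_plugin_manager("/tmp/other-plugins"); // cached, path ignored
        assert!(Arc::ptr_eq(first, second));
    }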
diff --git a/datafusion/src/plugin/udf.rs b/datafusion/src/plugin/udf.rs
index ffbb928fbd0f3..7f223ee69570f 100644
--- a/datafusion/src/plugin/udf.rs
+++ b/datafusion/src/plugin/udf.rs
@@ -1,8 +1,25 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::udaf::AggregateUDF;
 use crate::physical_plan::udf::ScalarUDF;
-use crate::plugin::{Plugin, PluginRegistrar};
-use libloading::Library;
+use crate::plugin::plugin_manager::global_plugin_manager;
+use crate::plugin::{Plugin, PluginEnum, PluginRegistrar};
+use libloading::{Library, Symbol};
 use std::any::Any;
 use std::collections::HashMap;
 use std::io;
 use std::sync::Arc;
@@ -24,7 +41,7 @@ pub trait UDFPlugin: Plugin {
 }
 
 /// UDFPluginManager
-#[derive(Default)]
+#[derive(Default, Clone)]
 pub struct UDFPluginManager {
     /// scalar udfs
     pub scalar_udfs: HashMap<String, Arc<ScalarUDF>>,
@@ -37,52 +54,98 @@ pub struct UDFPluginManager {
 }
 
 impl PluginRegistrar for UDFPluginManager {
-    fn register_plugin(&mut self, plugin: Box<dyn Plugin>) -> Result<()> {
-        if let Some(udf_plugin) = plugin.as_any().downcast_ref::<Box<dyn UDFPlugin>>() {
-            udf_plugin
-                .udf_names()
-                .unwrap()
-                .iter()
-                .try_for_each(|udf_name| {
-                    if self.scalar_udfs.contains_key(udf_name) {
-                        Err(DataFusionError::IoError(io::Error::new(
-                            io::ErrorKind::Other,
-                            format!("udf name: {} already exists", udf_name),
-                        )))
-                    } else {
-                        let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?;
-                        self.scalar_udfs
-                            .insert(udf_name.to_string(), Arc::new(scalar_udf));
-                        Ok(())
-                    }
-                })?;
+    unsafe fn load(&mut self, library: Arc<Library>) -> Result<()> {
+        type PluginRegister = unsafe fn() -> Box<dyn UDFPlugin>;
+        let register_fun: Symbol<PluginRegister> =
+            library.get(b"registrar_udf_plugin\0").map_err(|e| {
+                DataFusionError::IoError(io::Error::new(
+                    io::ErrorKind::Other,
+                    format!("not found fn registrar_udf_plugin in the library: {}", e),
+                ))
+            })?;
 
-            udf_plugin
-                .udaf_names()
-                .unwrap()
-                .iter()
-                .try_for_each(|udaf_name| {
-                    if self.aggregate_udfs.contains_key(udaf_name) {
-                        Err(DataFusionError::IoError(io::Error::new(
-                            io::ErrorKind::Other,
-                            format!("udaf name: {} already exists", udaf_name),
-                        )))
-                    } else {
-                        let aggregate_udf =
-                            udf_plugin.get_aggregate_udf_by_name(udaf_name)?;
-                        self.aggregate_udfs
-                            .insert(udaf_name.to_string(), Arc::new(aggregate_udf));
-                        Ok(())
-                    }
-                })?;
-        }
-        Err(DataFusionError::IoError(io::Error::new(
-            io::ErrorKind::Other,
-            format!("expected plugin type is 'dyn UDFPlugin', but it's not"),
-        )))
+        let udf_plugin: Box<dyn UDFPlugin> = register_fun();
+        udf_plugin
+            .udf_names()
+            .unwrap()
+            .iter()
+            .try_for_each(|udf_name| {
+                if self.scalar_udfs.contains_key(udf_name) {
+                    Err(DataFusionError::IoError(io::Error::new(
+                        io::ErrorKind::Other,
+                        format!("udf name: {} already exists", udf_name),
+                    )))
+                } else {
+                    let scalar_udf = udf_plugin.get_scalar_udf_by_name(udf_name)?;
+                    self.scalar_udfs
+                        .insert(udf_name.to_string(), Arc::new(scalar_udf));
+                    Ok(())
+                }
+            })?;
+
+        udf_plugin
+            .udaf_names()
+            .unwrap()
+            .iter()
+            .try_for_each(|udaf_name| {
+                if self.aggregate_udfs.contains_key(udaf_name) {
+                    Err(DataFusionError::IoError(io::Error::new(
+                        io::ErrorKind::Other,
+                        format!("udaf name: {} already exists", udaf_name),
+                    )))
+                } else {
+                    let aggregate_udf =
+                        udf_plugin.get_aggregate_udf_by_name(udaf_name)?;
+                    self.aggregate_udfs
+                        .insert(udaf_name.to_string(), Arc::new(aggregate_udf));
+                    Ok(())
+                }
+            })?;
+        Ok(())
     }
 
     fn as_any(&self) -> &dyn Any {
         self
     }
 }
+
+/// Declare a udf plugin registrar callback
+///
+/// # Notes
+///
+/// This works by automatically generating an `extern "C"` function named `registrar_udf_plugin` with a
+/// pre-defined signature and symbol name.
+/// Therefore you will only be able to declare one plugin per library.
+#[macro_export]
+macro_rules! declare_udf_plugin {
+    ($curr_plugin_type:ty, $constructor:path) => {
+        #[no_mangle]
+        pub extern "C" fn registrar_udf_plugin() -> Box<dyn $crate::plugin::udf::UDFPlugin> {
+            // make sure the constructor is the correct type.
+            let constructor: fn() -> $curr_plugin_type = $constructor;
+            let object = constructor();
+            Box::new(object)
+        }
+
+        $crate::declare_plugin!($crate::plugin::PluginEnum::UDF);
+    };
+}
+
+/// Get an `Option` of an immutable UDFPluginManager
+pub fn get_udf_plugin_manager(path: &str) -> Option<UDFPluginManager> {
+    let udf_plugin_manager_opt = {
+        let gpm = global_plugin_manager(path).lock().unwrap();
+        let plugin_registrar_opt = gpm.plugin_managers.get(&PluginEnum::UDF);
+        if let Some(plugin_registrar) = plugin_registrar_opt {
+            if let Some(udf_plugin_manager) =
+                plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
+            {
+                return Some(udf_plugin_manager.clone());
+            } else {
+                return None;
+            }
+        }
+        None
+    };
+    udf_plugin_manager_opt
+}
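Putting the pieces together, a complete UDF plugin is a `cdylib` crate that implements `UDFPlugin` and invokes `declare_udf_plugin!`. The following sketch is hypothetical: the `SimpleUdfPlugin` type and the `add_one` function are illustrative only, and it assumes a dependency on this datafusion fork with `crate-type = ["cdylib"]` in Cargo.toml:

    use datafusion::arrow::array::{ArrayRef, Float64Array};
    use datafusion::arrow::datatypes::DataType;
    use datafusion::declare_udf_plugin;
    use datafusion::error::{DataFusionError, Result};
    use datafusion::physical_plan::functions::make_scalar_function;
    use datafusion::physical_plan::udaf::AggregateUDF;
    use datafusion::physical_plan::udf::{create_udf, ScalarUDF};
    use datafusion::plugin::udf::UDFPlugin;
    use datafusion::plugin::Plugin;
    use std::any::Any;
    use std::sync::Arc;

    #[derive(Default)]
    struct SimpleUdfPlugin;

    impl Plugin for SimpleUdfPlugin {
        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    impl UDFPlugin for SimpleUdfPlugin {
        fn get_scalar_udf_by_name(&self, fun_name: &str) -> Result<ScalarUDF> {
            match fun_name {
                "add_one" => {
                    // Element-wise f64 + 1.0, null-preserving.
                    let fun = make_scalar_function(|args: &[ArrayRef]| {
                        let input =
                            args[0].as_any().downcast_ref::<Float64Array>().unwrap();
                        let out: Float64Array =
                            input.iter().map(|v| v.map(|x| x + 1.0)).collect();
                        Ok(Arc::new(out) as ArrayRef)
                    });
                    Ok(create_udf(
                        "add_one",
                        vec![DataType::Float64],
                        Arc::new(DataType::Float64),
                        fun,
                    ))
                }
                other => Err(DataFusionError::Plan(format!("unknown udf: {}", other))),
            }
        }

        fn udf_names(&self) -> Result<Vec<String>> {
            Ok(vec!["add_one".to_string()])
        }

        fn get_aggregate_udf_by_name(&self, fun_name: &str) -> Result<AggregateUDF> {
            // No UDAFs in this sketch.
            Err(DataFusionError::Plan(format!("unknown udaf: {}", fun_name)))
        }

        fn udaf_names(&self) -> Result<Vec<String>> {
            Ok(vec![])
        }
    }

    declare_udf_plugin!(SimpleUdfPlugin, SimpleUdfPlugin::default);

Dropping the compiled library into the configured plugin dir makes `add_one` resolvable by name through `get_udf_plugin_manager` and usable from SQL.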