diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 54d3e9e47..97d36b7cb 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -10,6 +10,7 @@ use crate::logical_plan::analyze::model_generation::ModelGenerationRule; use crate::logical_plan::optimize::simplify_timestamp::TimestampSimplify; use crate::logical_plan::utils::create_schema; use crate::mdl::manifest::Model; +use crate::mdl::type_planner::WrenTypePlanner; use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; @@ -74,9 +75,11 @@ pub async fn create_ctx_with_mdl( .set("datafusion.execution.time_zone", &session_timezone)?; } + let type_planner = Arc::new(WrenTypePlanner::default()); let reset_default_catalog_schema = Arc::new(RwLock::new( SessionStateBuilder::new_from_existing(ctx.state()) .with_config(config.clone()) + .with_type_planner(type_planner) .build(), )); diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 5c4e9dc45..2795fed00 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -43,6 +43,7 @@ pub mod lineage; pub mod manifest { pub use wren_core_base::mdl::manifest::*; } +pub mod type_planner; pub mod utils; pub type SessionStateRef = Arc>; @@ -549,7 +550,7 @@ mod test { use std::sync::Arc; use crate::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; - use crate::mdl::context::{create_ctx_with_mdl, Mode}; + use crate::mdl::context::{create_ctx_with_mdl, Mode, SessionPropertiesRef}; use crate::mdl::function::RemoteFunction; use crate::mdl::manifest::DataSource::MySQL; use crate::mdl::manifest::Manifest; @@ -3525,6 +3526,25 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_compatible_type() -> Result<()> { + let ctx = SessionContext::new(); + + let manifest = ManifestBuilder::default().build(); + let properties = SessionPropertiesRef::default(); + let mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::clone(&properties), + Mode::Unparse, + )?); + let sql = "select cast(1 as int64)"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], Arc::clone(&properties), sql).await?, + @"SELECT CAST(1 AS BIGINT)" + ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); diff --git a/wren-core/core/src/mdl/type_planner.rs b/wren-core/core/src/mdl/type_planner.rs new file mode 100644 index 000000000..1165a4395 --- /dev/null +++ b/wren-core/core/src/mdl/type_planner.rs @@ -0,0 +1,33 @@ +use datafusion::{ + arrow::datatypes::{DataType, TimeUnit}, + error::Result, + logical_expr::planner::TypePlanner, + sql::sqlparser::ast::DataType as SQLDataType, +}; + +#[derive(Debug, Clone, Default)] +pub struct WrenTypePlanner {} + +impl TypePlanner for WrenTypePlanner { + fn plan_type(&self, sql_type: &SQLDataType) -> Result> { + match sql_type { + SQLDataType::Int64 => Ok(Some(DataType::Int64)), + SQLDataType::Int32 => Ok(Some(DataType::Int32)), + SQLDataType::Float32 => Ok(Some(DataType::Float32)), + SQLDataType::Float64 => Ok(Some(DataType::Float64)), + SQLDataType::Datetime(precision) + if precision.is_none() || [0, 3, 6, 9].contains(&precision.unwrap()) => + { + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(Some(DataType::Timestamp(precision, None))) + } + _ => Ok(None), + } + } +} diff --git a/wren-core/sqllogictest/src/test_context.rs b/wren-core/sqllogictest/src/test_context.rs index fa80c6656..707e9f976 100644 --- a/wren-core/sqllogictest/src/test_context.rs +++ b/wren-core/sqllogictest/src/test_context.rs @@ -28,7 +28,7 @@ use tempfile::TempDir; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, ViewBuilder, }; -use wren_core::mdl::context::{create_ctx_with_mdl, Mode}; +use wren_core::mdl::context::{create_ctx_with_mdl, Mode, SessionPropertiesRef}; use wren_core::mdl::manifest::JoinType; use wren_core::mdl::AnalyzedWrenMDL; @@ -77,7 +77,17 @@ impl TestContext { } _ => { info!("Using default SessionContext"); - None + let mdl = Arc::new(AnalyzedWrenMDL::default()); + let ctx = create_ctx_with_mdl( + &ctx, + mdl.clone(), + SessionPropertiesRef::default(), + Mode::LocalRuntime, + ) + .await + .ok() + .unwrap(); + Some(TestContext::new(ctx, mdl)) } } } diff --git a/wren-core/sqllogictest/test_files/type.slt b/wren-core/sqllogictest/test_files/type.slt new file mode 100644 index 000000000..f59338dfe --- /dev/null +++ b/wren-core/sqllogictest/test_files/type.slt @@ -0,0 +1,19 @@ +query II +select cast(1 as int64) as c1, cast(1 as int32) as c2 +---- +1 1 + +query RR +select cast(1.0 as float64) as c1, cast(1.0 as float32) as c2 +---- +1 1 + + +query PPPP +select + cast('2000-01-01 10:00:00' as datetime), + cast('2000-01-01 10:00:00.123' as datetime), + cast('2000-01-01 10:00:00.123456' as datetime), + cast('2000-01-01 10:00:00.123456789' as datetime) +---- +2000-01-01T10:00:00 2000-01-01T10:00:00.123 2000-01-01T10:00:00.123456 2000-01-01T10:00:00.123456789