diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..272dfd7
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,85 @@
+name: CI
+on:
+  schedule: [{cron: "30 13 * * *"}]
+  push:
+    branches:
+      - master
+  pull_request:
+
+jobs:
+  format:
+    name: Rustfmt
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: nightly
+          components: rustfmt
+          override: true
+      - run: cargo fmt -- --check
+
+  lint:
+    name: Clippy lint
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: nightly
+          components: clippy
+          override: true
+      - name: Lint
+        run: cargo clippy --all-features -- -D warnings
+
+  test:
+    name: Test
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        build: [stable, nightly]
+        include:
+          - build: stable
+            os: ubuntu-latest
+            rust: stable
+          - build: nightly
+            os: ubuntu-latest
+            rust: nightly
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: ${{ matrix.rust }}
+          override: true
+      - name: Build and run tests
+        run: cargo test --all-features
+
+  # integration:
+  #   name: Integration tests
+  #   runs-on: ubuntu-latest
+  #   timeout-minutes: 15
+  #   needs: [test]
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - uses: actions-rs/toolchain@v1
+  #       with:
+  #         toolchain: stable
+  #         override: true
+  #     - run: |
+  #         pip install psycopg
+  #         pip install psycopg2
+  #     - uses: turtlequeue/setup-babashka@v1.5.0
+  #       with:
+  #         babashka-version: 1.1.173
+  #     - run: ./tests-integration/test.sh
+
+  msrv:
+    name: MSRV
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions-rs/toolchain@v1
+        with:
+          toolchain: "1.72"
+          override: true
+      - run: cargo build --all-features
diff --git a/src/datatypes.rs b/src/datatypes.rs
index 69925df..f42528a 100644
--- a/src/datatypes.rs
+++ b/src/datatypes.rs
@@ -6,8 +6,11 @@ use datafusion::arrow::datatypes::{
     UInt32Type, UInt64Type, UInt8Type,
 };
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::common::{DFSchema, ParamValues};
 use datafusion::prelude::*;
+use datafusion::scalar::ScalarValue;
 use futures::{stream, StreamExt};
+use pgwire::api::portal::Portal;
 use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse};
 use pgwire::api::Type;
 use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
@@ -214,24 +217,25 @@ fn encode_value(
     Ok(())
 }
 
+pub(crate) fn df_schema_to_pg_fields(schema: &DFSchema) -> PgWireResult<Vec<FieldInfo>> {
+    schema
+        .fields()
+        .iter()
+        .map(|f| {
+            let pg_type = into_pg_type(f.data_type())?;
+            Ok(FieldInfo::new(
+                f.name().into(),
+                None,
+                None,
+                pg_type,
+                FieldFormat::Text,
+            ))
+        })
+        .collect::<PgWireResult<Vec<FieldInfo>>>()
+}
+
 pub(crate) async fn encode_dataframe<'a>(df: DataFrame) -> PgWireResult<QueryResponse<'a>> {
-    let schema = df.schema();
-    let fields = Arc::new(
-        schema
-            .fields()
-            .iter()
-            .map(|f| {
-                let pg_type = into_pg_type(f.data_type())?;
-                Ok(FieldInfo::new(
-                    f.name().into(),
-                    None,
-                    None,
-                    pg_type,
-                    FieldFormat::Text,
-                ))
-            })
-            .collect::<PgWireResult<Vec<FieldInfo>>>()?,
-    );
+    let fields = Arc::new(df_schema_to_pg_fields(df.schema())?);
 
     let recordbatch_stream = df
         .execute_stream()
@@ -266,3 +270,86 @@ pub(crate) async fn encode_dataframe<'a>(df: DataFrame) -> PgWireResult<QueryResponse<'a>> {
 
     Ok(QueryResponse::new(fields, pg_row_stream))
 }
+
+pub(crate) fn deserialize_parameters<S>(
+    portal: &Portal<S>,
+    inferred_types: &[Option<&DataType>],
+) -> PgWireResult<ParamValues>
+where
+    S: Clone,
+{
+    fn get_pg_type(
+        pg_type_hint: Option<&Type>,
+        inferred_type: Option<&DataType>,
+    ) -> PgWireResult<Type> {
+        if let Some(ty) = pg_type_hint {
+            Ok(ty.clone())
+        } else if let Some(infer_type) = inferred_type {
+            into_pg_type(infer_type)
+        } else {
+            Err(PgWireError::UserError(Box::new(ErrorInfo::new(
+                "FATAL".to_string(),
+                "XX000".to_string(),
+                "Unknown parameter type".to_string(),
+            ))))
+        }
+    }
+
+    let param_len = portal.parameter_len();
+    let mut deserialized_params = Vec::with_capacity(param_len);
+    for i in 0..param_len {
+        let pg_type = get_pg_type(
+            portal.statement.parameter_types.get(i),
+            inferred_types.get(i).and_then(|v| v.to_owned()),
+        )?;
+        match pg_type {
+            // enumerate all supported parameter types and deserialize the
+            // type to ScalarValue
+            Type::BOOL => {
+                let value = portal.parameter::<bool>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Boolean(value));
+            }
+            Type::INT2 => {
+                let value = portal.parameter::<i16>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Int16(value));
+            }
+            Type::INT4 => {
+                let value = portal.parameter::<i32>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Int32(value));
+            }
+            Type::INT8 => {
+                let value = portal.parameter::<i64>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Int64(value));
+            }
+            Type::TEXT | Type::VARCHAR => {
+                let value = portal.parameter::<String>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Utf8(value));
+            }
+            Type::FLOAT4 => {
+                let value = portal.parameter::<f32>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Float32(value));
+            }
+            Type::FLOAT8 => {
+                let value = portal.parameter::<f64>(i, &pg_type)?;
+                deserialized_params.push(ScalarValue::Float64(value));
+            }
+            // TODO: add more types like Timestamp, Datetime, Bytea
+            _ => {
+                return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
+                    "FATAL".to_string(),
+                    "XX000".to_string(),
+                    format!("Unsupported parameter type: {}", pg_type),
+                ))));
+            }
+        }
+    }
+
+    Ok(ParamValues::List(deserialized_params))
+}
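
deserialize_parameters only produces a ParamValues; the substitution into the prepared plan happens later in the extended query handler in src/handlers.rs. A minimal standalone sketch of that round trip, using only DataFusion (the SQL string and the literal 41 are illustrative, not part of the diff):

// Sketch: a ParamValues built from ScalarValues replaces a $1
// placeholder in a logical plan, mirroring what do_query does below.
use datafusion::common::ParamValues;
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Plan a query that still contains a placeholder.
    let plan = ctx.state().create_logical_plan("SELECT $1 + 1").await?;
    // Substitute the deserialized parameter, then execute the plan.
    let plan = plan.replace_params_with_values(&ParamValues::List(vec![
        ScalarValue::Int64(Some(41)),
    ]))?;
    ctx.execute_logical_plan(plan).await?.show().await?;
    Ok(())
}
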
diff --git a/src/handlers.rs b/src/handlers.rs
index 5d6dade..dab019d 100644
--- a/src/handlers.rs
+++ b/src/handlers.rs
@@ -1,23 +1,35 @@
 use std::sync::Arc;
 
 use async_trait::async_trait;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::logical_expr::LogicalPlan;
 use datafusion::prelude::*;
-use pgwire::api::query::SimpleQueryHandler;
-use pgwire::api::results::{Response, Tag};
-use pgwire::api::ClientInfo;
+use pgwire::api::portal::Portal;
+use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
+use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response, Tag};
+use pgwire::api::stmt::QueryParser;
+use pgwire::api::stmt::StoredStatement;
+use pgwire::api::{ClientInfo, Type};
 use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
+
 use tokio::sync::Mutex;
 
-use crate::datatypes::encode_dataframe;
+use crate::datatypes::{self, into_pg_type};
 
 pub(crate) struct DfSessionService {
     session_context: Arc<Mutex<SessionContext>>,
+    parser: Arc<Parser>,
 }
 
 impl DfSessionService {
     pub fn new() -> DfSessionService {
+        let session_context = Arc::new(Mutex::new(SessionContext::new()));
+        let parser = Arc::new(Parser {
+            session_context: session_context.clone(),
+        });
         DfSessionService {
-            session_context: Arc::new(Mutex::new(SessionContext::new())),
+            session_context,
+            parser,
         }
     }
 }
@@ -50,7 +62,7 @@ impl SimpleQueryHandler for DfSessionService {
             .await
             .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
 
-            let resp = encode_dataframe(df).await?;
+            let resp = datatypes::encode_dataframe(df).await?;
             Ok(vec![Response::Query(resp)])
         } else {
             Ok(vec![Response::Error(Box::new(ErrorInfo::new(
@@ -61,3 +73,120 @@
         }
     }
 }
+
+pub(crate) struct Parser {
+    session_context: Arc<Mutex<SessionContext>>,
+}
+
+#[async_trait]
+impl QueryParser for Parser {
+    type Statement = LogicalPlan;
+
+    async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
+        let context = self.session_context.lock().await;
+        let state = context.state();
+
+        let logical_plan = state
+            .create_logical_plan(sql)
+            .await
+            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+        let optimised = state
+            .optimize(&logical_plan)
+            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+
+        Ok(optimised)
+    }
+}
+
+#[async_trait]
+impl ExtendedQueryHandler for DfSessionService {
+    type Statement = LogicalPlan;
+
+    type QueryParser = Parser;
+
+    fn query_parser(&self) -> Arc<Self::QueryParser> {
+        self.parser.clone()
+    }
+
+    async fn do_describe_statement<C>(
+        &self,
+        _client: &mut C,
+        target: &StoredStatement<Self::Statement>,
+    ) -> PgWireResult<DescribeStatementResponse>
+    where
+        C: ClientInfo + Unpin + Send + Sync,
+    {
+        let plan = &target.statement;
+
+        let schema = plan.schema();
+        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref())?;
+        let params = plan
+            .get_parameter_types()
+            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+
+        let mut param_types = Vec::with_capacity(params.len());
+        for param_type in params.into_values() {
+            if let Some(datatype) = param_type {
+                let pgtype = into_pg_type(&datatype)?;
+                param_types.push(pgtype);
+            } else {
+                param_types.push(Type::UNKNOWN);
+            }
+        }
+
+        Ok(DescribeStatementResponse::new(param_types, fields))
+    }
+
+    async fn do_describe_portal<C>(
+        &self,
+        _client: &mut C,
+        target: &Portal<Self::Statement>,
+    ) -> PgWireResult<DescribePortalResponse>
+    where
+        C: ClientInfo + Unpin + Send + Sync,
+    {
+        let plan = &target.statement.statement;
+        let schema = plan.schema();
+        let fields = datatypes::df_schema_to_pg_fields(schema.as_ref())?;
+
+        Ok(DescribePortalResponse::new(fields))
+    }
+
+    async fn do_query<'a, C>(
+        &self,
+        _client: &mut C,
+        portal: &'a Portal<Self::Statement>,
+        _max_rows: usize,
+    ) -> PgWireResult<Response<'a>>
+    where
+        C: ClientInfo + Unpin + Send + Sync,
+    {
+        let plan = &portal.statement.statement;
+
+        let param_values = datatypes::deserialize_parameters(
+            portal,
+            &plan
+                .get_parameter_types()
+                .map_err(|e| PgWireError::ApiError(Box::new(e)))?
+                .values()
+                .map(|v| v.as_ref())
+                .collect::<Vec<Option<&DataType>>>(),
+        )?;
+
+        let plan = plan
+            .replace_params_with_values(&param_values)
+            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+
+        let dataframe = self
+            .session_context
+            .lock()
+            .await
+            .execute_logical_plan(plan)
+            .await
+            .map_err(|e| PgWireError::ApiError(Box::new(e)))?;
+
+        let resp = datatypes::encode_dataframe(dataframe).await?;
+        Ok(Response::Query(resp))
+    }
+}
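
do_describe_statement leans on LogicalPlan::get_parameter_types, which returns None for any placeholder whose type DataFusion could not infer; those are reported back to the client as Type::UNKNOWN. A standalone sketch of that inference step, assuming only DataFusion (the query string is illustrative):

// Sketch: inspect the placeholder types DataFusion inferred for a plan,
// the same map that do_describe_statement converts to Postgres types.
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    let plan = ctx
        .state()
        .create_logical_plan("SELECT $1::BIGINT + $2::BIGINT")
        .await?;
    // Keys are placeholder ids ("$1", "$2"); values are Option<DataType>.
    for (id, data_type) in plan.get_parameter_types()? {
        println!("{id}: {data_type:?}");
    }
    Ok(())
}
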
diff --git a/src/main.rs b/src/main.rs
index 8437490..4b7eed0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,7 +1,6 @@
 use std::sync::Arc;
 
 use pgwire::api::auth::noop::NoopStartupHandler;
-use pgwire::api::query::PlaceholderExtendedQueryHandler;
 use pgwire::api::{MakeHandler, StatelessMakeHandler};
 use pgwire::tokio::process_socket;
 use tokio::net::TcpListener;
@@ -14,10 +13,6 @@ async fn main() {
     let processor = Arc::new(StatelessMakeHandler::new(Arc::new(
         handlers::DfSessionService::new(),
     )));
-    // We have not implemented extended query in this server, use placeholder instead
-    let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new(
-        PlaceholderExtendedQueryHandler,
-    )));
     let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler)));
 
     let server_addr = "127.0.0.1:5432";
@@ -28,14 +23,13 @@
         let incoming_socket = listener.accept().await.unwrap();
         let authenticator_ref = authenticator.make();
         let processor_ref = processor.make();
-        let placeholder_ref = placeholder.make();
        tokio::spawn(async move {
            process_socket(
                incoming_socket.0,
                None,
                authenticator_ref,
+               processor_ref.clone(),
                processor_ref,
-               placeholder_ref,
            )
            .await
        });
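
With the placeholder handler removed, the same DfSessionService instance now backs both the simple and the extended protocol paths. Below is a hypothetical smoke test against the running server; tokio-postgres is not a dependency of this repo, host and port follow main.rs, and any user name passes the NoopStartupHandler. One caveat: df_schema_to_pg_fields always declares FieldFormat::Text, so a client that insists on binary result encoding may fail to decode data rows; the sketch therefore fetches rows over the simple protocol and only checks parameter inference over the extended one.

// Sketch: exercise both protocol paths of the server started by main.rs.
use tokio_postgres::{NoTls, SimpleQueryMessage};

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    let (client, connection) =
        tokio_postgres::connect("host=127.0.0.1 port=5432 user=postgres", NoTls).await?;
    // The connection object drives the socket; poll it in the background.
    tokio::spawn(connection);

    // Simple protocol -> SimpleQueryHandler::do_query.
    for msg in client.simple_query("SELECT 1 + 1").await? {
        if let SimpleQueryMessage::Row(row) = msg {
            println!("simple query result: {:?}", row.get(0));
        }
    }

    // Extended protocol -> Parser::parse_sql + do_describe_statement.
    let stmt = client.prepare("SELECT $1::BIGINT + 1").await?;
    println!("inferred parameter types: {:?}", stmt.params());
    Ok(())
}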