85 changes: 85 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,85 @@
name: CI
on:
schedule: [{cron: "30 13 * * *"}]
push:
branches:
- master
pull_request:

jobs:
format:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly
components: rustfmt
override: true
- run: cargo fmt -- --check

lint:
name: Clippy lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly
components: clippy
override: true
- name: Lint
run: cargo clippy --all-features -- -D warnings

test:
name: Test
runs-on: ${{ matrix.os }}
strategy:
matrix:
build: [stable, nightly]
include:
- build: stable
os: ubuntu-latest
rust: stable
- build: nightly
os: ubuntu-latest
rust: nightly
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: ${{ matrix.rust }}
override: true
- name: Build and run tests
run: cargo test --all-features

# integration:
# name: Integration tests
# runs-on: ubuntu-latest
# timeout-minutes: 15
# needs: [test]
# steps:
# - uses: actions/checkout@v4
# - uses: actions-rs/toolchain@v1
# with:
# toolchain: stable
# override: true
# - run: |
# pip install psycopg
# pip install psycopg2
# - uses: turtlequeue/[email protected]
# with:
# babashka-version: 1.1.173
# - run: ./tests-integration/test.sh

msrv:
name: MSRV
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: "1.72"
override: true
- run: cargo build --all-features
121 changes: 104 additions & 17 deletions src/datatypes.rs
@@ -6,8 +6,11 @@ use datafusion::arrow::datatypes::{
UInt32Type, UInt64Type, UInt8Type,
};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::{DFSchema, ParamValues};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use futures::{stream, StreamExt};
use pgwire::api::portal::Portal;
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse};
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
@@ -214,24 +217,25 @@ fn encode_value(
Ok(())
}

pub(crate) fn df_schema_to_pg_fields(schema: &DFSchema) -> PgWireResult<Vec<FieldInfo>> {
schema
.fields()
.iter()
.map(|f| {
let pg_type = into_pg_type(f.data_type())?;
Ok(FieldInfo::new(
f.name().into(),
None,
None,
pg_type,
FieldFormat::Text,
))
})
.collect::<PgWireResult<Vec<FieldInfo>>>()
}

pub(crate) async fn encode_dataframe<'a>(df: DataFrame) -> PgWireResult<QueryResponse<'a>> {
-    let schema = df.schema();
-    let fields = Arc::new(
-        schema
-            .fields()
-            .iter()
-            .map(|f| {
-                let pg_type = into_pg_type(f.data_type())?;
-                Ok(FieldInfo::new(
-                    f.name().into(),
-                    None,
-                    None,
-                    pg_type,
-                    FieldFormat::Text,
-                ))
-            })
-            .collect::<PgWireResult<Vec<FieldInfo>>>()?,
-    );
+    let fields = Arc::new(df_schema_to_pg_fields(df.schema())?);

let recordbatch_stream = df
.execute_stream()
@@ -266,3 +270,86 @@ pub(crate) async fn encode_dataframe<'a>(df: DataFrame) -> PgWireResult<QueryResponse<'a>> {

Ok(QueryResponse::new(fields, pg_row_stream))
}

/// Deserialize client-provided parameter data.
///
/// First we try the type information in `pg_type_hint`, which is supplied by
/// the client. If the hint is missing or unknown, we fall back to the type
/// DataFusion inferred in `inferenced_types`. An error is raised when neither
/// source can provide type information.
pub(crate) fn deserialize_parameters<S>(
portal: &Portal<S>,
inferenced_types: &[Option<&DataType>],
) -> PgWireResult<ParamValues>
where
S: Clone,
{
fn get_pg_type(
pg_type_hint: Option<&Type>,
inferenced_type: Option<&DataType>,
) -> PgWireResult<Type> {
if let Some(ty) = pg_type_hint {
Ok(ty.clone())
} else if let Some(infer_type) = inferenced_type {
into_pg_type(infer_type)
} else {
Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
"XX000".to_string(),
"Unknown parameter type".to_string(),
))))
}
}

let param_len = portal.parameter_len();
let mut deserialized_params = Vec::with_capacity(param_len);
for i in 0..param_len {
let pg_type = get_pg_type(
portal.statement.parameter_types.get(i),
inferenced_types.get(i).and_then(|v| v.to_owned()),
)?;
match pg_type {
            // enumerate all supported parameter types and deserialize each
            // value into a ScalarValue
Type::BOOL => {
let value = portal.parameter::<bool>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Boolean(value));
}
Type::INT2 => {
let value = portal.parameter::<i16>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Int16(value));
}
Type::INT4 => {
let value = portal.parameter::<i32>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Int32(value));
}
Type::INT8 => {
let value = portal.parameter::<i64>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Int64(value));
}
Type::TEXT | Type::VARCHAR => {
let value = portal.parameter::<String>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Utf8(value));
}
Type::FLOAT4 => {
let value = portal.parameter::<f32>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Float32(value));
}
Type::FLOAT8 => {
let value = portal.parameter::<f64>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Float64(value));
}
// TODO: add more types like Timestamp, Datetime, Bytea
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_string(),
"XX000".to_string(),
format!("Unsupported parameter type: {}", pg_type),
))));
}
}
}

Ok(ParamValues::List(deserialized_params))
}
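
For context on what `deserialize_parameters` feeds into: the `ParamValues::List` it returns is handed to DataFusion's `LogicalPlan::replace_params_with_values`, which substitutes the `$n` placeholders before execution. Below is a minimal standalone sketch of that DataFusion-side flow; the table name `example`, the file `data.csv`, and the column `id` are illustrative assumptions, not part of this change.

use datafusion::common::ParamValues;
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Hypothetical table; any registered source works.
    ctx.register_csv("example", "data.csv", CsvReadOptions::new())
        .await?;

    // Plan a query containing a $1 placeholder, as Parser::parse_sql does.
    let plan = ctx
        .state()
        .create_logical_plan("SELECT * FROM example WHERE id = $1")
        .await?;

    // This mirrors do_query: bind the deserialized parameters, then execute.
    let params = ParamValues::List(vec![ScalarValue::Int64(Some(42))]);
    let bound = plan.replace_params_with_values(&params)?;
    ctx.execute_logical_plan(bound).await?.show().await?;
    Ok(())
}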
141 changes: 135 additions & 6 deletions src/handlers.rs
@@ -1,23 +1,35 @@
use std::sync::Arc;

use async_trait::async_trait;
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::*;
-use pgwire::api::query::SimpleQueryHandler;
-use pgwire::api::results::{Response, Tag};
-use pgwire::api::ClientInfo;
+use pgwire::api::portal::Portal;
+use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
+use pgwire::api::results::{DescribePortalResponse, DescribeStatementResponse, Response, Tag};
+use pgwire::api::stmt::QueryParser;
+use pgwire::api::stmt::StoredStatement;
+use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};

use tokio::sync::Mutex;

-use crate::datatypes::encode_dataframe;
+use crate::datatypes::{self, into_pg_type};

pub(crate) struct DfSessionService {
session_context: Arc<Mutex<SessionContext>>,
parser: Arc<Parser>,
}

impl DfSessionService {
pub fn new() -> DfSessionService {
let session_context = Arc::new(Mutex::new(SessionContext::new()));
let parser = Arc::new(Parser {
session_context: session_context.clone(),
});
        DfSessionService {
-            session_context: Arc::new(Mutex::new(SessionContext::new())),
+            session_context,
+            parser,
        }
}
}
@@ -50,7 +62,7 @@ impl SimpleQueryHandler for DfSessionService {
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

-            let resp = encode_dataframe(df).await?;
+            let resp = datatypes::encode_dataframe(df).await?;
Ok(vec![Response::Query(resp)])
} else {
Ok(vec![Response::Error(Box::new(ErrorInfo::new(
@@ -61,3 +73,120 @@
}
}
}

pub(crate) struct Parser {
session_context: Arc<Mutex<SessionContext>>,
}

#[async_trait]
impl QueryParser for Parser {
type Statement = LogicalPlan;

async fn parse_sql(&self, sql: &str, _types: &[Type]) -> PgWireResult<Self::Statement> {
let context = self.session_context.lock().await;
let state = context.state();

let logical_plan = state
.create_logical_plan(sql)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let optimised = state
.optimize(&logical_plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

Ok(optimised)
}
}

#[async_trait]
impl ExtendedQueryHandler for DfSessionService {
type Statement = LogicalPlan;

type QueryParser = Parser;

fn query_parser(&self) -> Arc<Self::QueryParser> {
self.parser.clone()
}

async fn do_describe_statement<C>(
&self,
_client: &mut C,
target: &StoredStatement<Self::Statement>,
) -> PgWireResult<DescribeStatementResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
let plan = &target.statement;

let schema = plan.schema();
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref())?;
let params = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let mut param_types = Vec::with_capacity(params.len());
for param_type in params.into_values() {
if let Some(datatype) = param_type {
let pgtype = into_pg_type(&datatype)?;
param_types.push(pgtype);
} else {
param_types.push(Type::UNKNOWN);
}
}

Ok(DescribeStatementResponse::new(param_types, fields))
}

async fn do_describe_portal<C>(
&self,
_client: &mut C,
target: &Portal<Self::Statement>,
) -> PgWireResult<DescribePortalResponse>
where
C: ClientInfo + Unpin + Send + Sync,
{
let plan = &target.statement.statement;
let schema = plan.schema();
let fields = datatypes::df_schema_to_pg_fields(schema.as_ref())?;

Ok(DescribePortalResponse::new(fields))
}

async fn do_query<'a, C>(
&self,
_client: &mut C,
portal: &'a Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response<'a>>
where
C: ClientInfo + Unpin + Send + Sync,
{
let plan = &portal.statement.statement;

let param_values = datatypes::deserialize_parameters(
portal,
&plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
.values()
.map(|v| v.as_ref())
.collect::<Vec<Option<&DataType>>>(),
)?;

let plan = plan
.replace_params_with_values(&param_values)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let dataframe = self
.session_context
.lock()
.await
.execute_logical_plan(plan)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let resp = datatypes::encode_dataframe(dataframe).await?;
Ok(Response::Query(resp))
}
}
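
To exercise the new extended-query path end to end, a prepared statement from any Postgres client goes through Parse/Bind and lands in `do_describe_statement` and `do_query` above. The commented-out integration job installs psycopg for this; the sketch below does the equivalent from Rust with `tokio-postgres`. The listen address `127.0.0.1:5432`, the credentials, and the `example` table are assumptions for illustration, not taken from this diff.

use tokio_postgres::NoTls;

#[tokio::main]
async fn main() -> Result<(), tokio_postgres::Error> {
    // Assumed address/credentials; adjust to however the server is started.
    let (client, connection) =
        tokio_postgres::connect("host=127.0.0.1 port=5432 user=postgres", NoTls).await?;
    // tokio-postgres requires the connection future to be driven separately.
    tokio::spawn(async move {
        if let Err(e) = connection.await {
            eprintln!("connection error: {e}");
        }
    });

    // `query` prepares the statement, so $1 travels through the extended
    // protocol and reaches deserialize_parameters as Type::INT8.
    let rows = client
        .query("SELECT id FROM example WHERE id > $1", &[&10_i64])
        .await?;
    for row in rows {
        let id: i64 = row.get(0);
        println!("id = {id}");
    }
    Ok(())
}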