Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 160 additions & 141 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use pgwire::api::auth::StartupHandler;
use pgwire::api::portal::{Format, Portal};
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{
DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
Response, Tag,
DescribePortalResponse, DescribeResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
QueryResponse, Response, Tag,
};
use pgwire::api::stmt::QueryParser;
use pgwire::api::stmt::StoredStatement;
Expand Down Expand Up @@ -438,97 +438,103 @@ impl SimpleQueryHandler for DfSessionService {
return Ok(vec![resp]);
}

let mut statements = self
let statements = self
.parser
.sql_parser
.parse(query)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

// TODO: deal with multiple statements
let statement = statements.remove(0);

// TODO: improve statement check by using statement directly
let query = statement.to_string();
let query_lower = query.to_lowercase().trim().to_string();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
&& !query_lower.starts_with("begin")
&& !query_lower.starts_with("commit")
&& !query_lower.starts_with("rollback")
&& !query_lower.starts_with("start")
&& !query_lower.starts_with("end")
&& !query_lower.starts_with("abort")
&& !query_lower.starts_with("show")
{
self.check_query_permission(client, &query).await?;
}
// empty query
if statements.is_empty() {
return Ok(vec![Response::EmptyQuery]);
}

let mut results = vec![];
for statement in statements {
// TODO: improve statement check by using statement directly
let query = statement.to_string();
let query_lower = query.to_lowercase().trim().to_string();

// Check permissions for the query (skip for SET, transaction, and SHOW statements)
if !query_lower.starts_with("set")
&& !query_lower.starts_with("begin")
&& !query_lower.starts_with("commit")
&& !query_lower.starts_with("rollback")
&& !query_lower.starts_with("start")
&& !query_lower.starts_with("end")
&& !query_lower.starts_with("abort")
&& !query_lower.starts_with("show")
{
self.check_query_permission(client, &query).await?;
}

if let Some(resp) = self
.try_respond_set_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}
if let Some(resp) = self
.try_respond_set_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}

if let Some(resp) = self
.try_respond_show_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}
if let Some(resp) = self
.try_respond_show_statements(client, &query_lower)
.await?
{
return Ok(vec![resp]);
}

// Check if we're in a failed transaction and block non-transaction
// commands
if client.transaction_status() == TransactionStatus::Error {
return Err(PgWireError::UserError(Box::new(
// Check if we're in a failed transaction and block non-transaction
// commands
if client.transaction_status() == TransactionStatus::Error {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"25P01".to_string(),
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
),
)));
}

let df_result = {
let timeout = Self::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
} else {
self.session_context.sql(&query).await
}
};

// Handle query execution errors and transaction state
let df = match df_result {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
}
};
let df_result = {
let timeout = Self::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
} else {
self.session_context.sql(&query).await
}
};

if query_lower.starts_with("insert into") {
let resp = map_rows_affected_for_insert(&df).await?;
Ok(vec![resp])
} else {
// For non-INSERT queries, return a regular Query response
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
Ok(vec![Response::Query(resp)])
// Handle query execution errors and transaction state
let df = match df_result {
Ok(df) => df,
Err(e) => {
return Err(PgWireError::ApiError(Box::new(e)));
}
};

if query_lower.starts_with("insert into") {
let resp = map_rows_affected_for_insert(&df).await?;
results.push(resp);
} else {
// For non-INSERT queries, return a regular Query response
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
results.push(Response::Query(resp));
}
}
Ok(results)
}
}

#[async_trait]
impl ExtendedQueryHandler for DfSessionService {
type Statement = (String, LogicalPlan);
type Statement = (String, Option<LogicalPlan>);
type QueryParser = Parser;

fn query_parser(&self) -> Arc<Self::QueryParser> {
Expand All @@ -543,25 +549,28 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
let (_, plan) = &target.statement;
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
let params = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let mut param_types = Vec::with_capacity(params.len());
for param_type in ordered_param_types(&params).iter() {
// Fixed: Use &params
if let Some(datatype) = param_type {
let pgtype = into_pg_type(datatype)?;
param_types.push(pgtype);
} else {
param_types.push(Type::UNKNOWN);
if let (_, Some(plan)) = &target.statement {
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
let params = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let mut param_types = Vec::with_capacity(params.len());
for param_type in ordered_param_types(&params).iter() {
// Fixed: Use &params
if let Some(datatype) = param_type {
let pgtype = into_pg_type(datatype)?;
param_types.push(pgtype);
} else {
param_types.push(Type::UNKNOWN);
}
}
}

Ok(DescribeStatementResponse::new(param_types, fields))
Ok(DescribeStatementResponse::new(param_types, fields))
} else {
Ok(DescribeStatementResponse::no_data())
}
}

async fn do_describe_portal<C>(
Expand All @@ -572,12 +581,15 @@ impl ExtendedQueryHandler for DfSessionService {
where
C: ClientInfo + Unpin + Send + Sync,
{
let (_, plan) = &target.statement.statement;
let format = &target.result_column_format;
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
if let (_, Some(plan)) = &target.statement.statement {
let format = &target.result_column_format;
let schema = plan.schema();
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;

Ok(DescribePortalResponse::new(fields))
Ok(DescribePortalResponse::new(fields))
} else {
Ok(DescribePortalResponse::no_data())
}
}

async fn do_query<C>(
Expand Down Expand Up @@ -631,57 +643,60 @@ impl ExtendedQueryHandler for DfSessionService {
)));
}

let (_, plan) = &portal.statement.statement;

let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types

let plan = plan
.clone()
.replace_params_with_values(&param_values)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
// &param_values
let optimised = self
.session_context
.state()
.optimize(&plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let dataframe = {
let timeout = Self::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
tokio::time::timeout(
timeout_duration,
self.session_context.execute_logical_plan(optimised),
)
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
} else {
self.session_context
.execute_logical_plan(optimised)
if let (_, Some(plan)) = &portal.statement.statement {
let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values =
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types

let plan = plan
.clone()
.replace_params_with_values(&param_values)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
// &param_values
let optimised = self
.session_context
.state()
.optimize(&plan)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let dataframe = {
let timeout = Self::get_statement_timeout(client);
if let Some(timeout_duration) = timeout {
tokio::time::timeout(
timeout_duration,
self.session_context.execute_logical_plan(optimised),
)
.await
.map_err(|_| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // query_canceled error code
"canceling statement due to statement timeout".to_string(),
)))
})?
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
}
};
} else {
self.session_context
.execute_logical_plan(optimised)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
}
};

if query.starts_with("insert into") {
let resp = map_rows_affected_for_insert(&dataframe).await?;
if query.starts_with("insert into") {
let resp = map_rows_affected_for_insert(&dataframe).await?;

Ok(resp)
Ok(resp)
} else {
// For non-INSERT queries, return a regular Query response
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
Ok(Response::Query(resp))
}
} else {
// For non-INSERT queries, return a regular Query response
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
Ok(Response::Query(resp))
Ok(Response::EmptyQuery)
}
}
}
Expand Down Expand Up @@ -767,7 +782,7 @@ impl Parser {

#[async_trait]
impl QueryParser for Parser {
type Statement = (String, LogicalPlan);
type Statement = (String, Option<LogicalPlan>);

async fn parse_sql<C>(
&self,
Expand All @@ -782,13 +797,17 @@ impl QueryParser for Parser {
.try_shortcut_parse_plan(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
return Ok((sql.to_string(), plan));
return Ok((sql.to_string(), Some(plan)));
}

let mut statements = self
.sql_parser
.parse(sql)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
if statements.is_empty() {
return Ok((sql.to_string(), None));
}

let statement = statements.remove(0);

let query = statement.to_string();
Expand All @@ -799,7 +818,7 @@ impl QueryParser for Parser {
.statement_to_plan(Statement::Statement(Box::new(statement)))
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
Ok((query, logical_plan))
Ok((query, Some(logical_plan)))
}
}

Expand Down
Loading