Skip to content

Commit a9a5769

Browse files
authored
fix: add proper check to deal with empty query and multiple queries in simple query (#203)
1 parent 24368e4 commit a9a5769

File tree

1 file changed

+160
-141
lines changed

1 file changed

+160
-141
lines changed

datafusion-postgres/src/handlers.rs

Lines changed: 160 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use pgwire::api::auth::StartupHandler;
1414
use pgwire::api::portal::{Format, Portal};
1515
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
1616
use pgwire::api::results::{
17-
DescribePortalResponse, DescribeStatementResponse, FieldFormat, FieldInfo, QueryResponse,
18-
Response, Tag,
17+
DescribePortalResponse, DescribeResponse, DescribeStatementResponse, FieldFormat, FieldInfo,
18+
QueryResponse, Response, Tag,
1919
};
2020
use pgwire::api::stmt::QueryParser;
2121
use pgwire::api::stmt::StoredStatement;
@@ -438,97 +438,103 @@ impl SimpleQueryHandler for DfSessionService {
438438
return Ok(vec![resp]);
439439
}
440440

441-
let mut statements = self
441+
let statements = self
442442
.parser
443443
.sql_parser
444444
.parse(query)
445445
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
446446

447-
// TODO: deal with multiple statements
448-
let statement = statements.remove(0);
449-
450-
// TODO: improve statement check by using statement directly
451-
let query = statement.to_string();
452-
let query_lower = query.to_lowercase().trim().to_string();
453-
454-
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
455-
if !query_lower.starts_with("set")
456-
&& !query_lower.starts_with("begin")
457-
&& !query_lower.starts_with("commit")
458-
&& !query_lower.starts_with("rollback")
459-
&& !query_lower.starts_with("start")
460-
&& !query_lower.starts_with("end")
461-
&& !query_lower.starts_with("abort")
462-
&& !query_lower.starts_with("show")
463-
{
464-
self.check_query_permission(client, &query).await?;
465-
}
447+
// empty query
448+
if statements.is_empty() {
449+
return Ok(vec![Response::EmptyQuery]);
450+
}
451+
452+
let mut results = vec![];
453+
for statement in statements {
454+
// TODO: improve statement check by using statement directly
455+
let query = statement.to_string();
456+
let query_lower = query.to_lowercase().trim().to_string();
457+
458+
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
459+
if !query_lower.starts_with("set")
460+
&& !query_lower.starts_with("begin")
461+
&& !query_lower.starts_with("commit")
462+
&& !query_lower.starts_with("rollback")
463+
&& !query_lower.starts_with("start")
464+
&& !query_lower.starts_with("end")
465+
&& !query_lower.starts_with("abort")
466+
&& !query_lower.starts_with("show")
467+
{
468+
self.check_query_permission(client, &query).await?;
469+
}
466470

467-
if let Some(resp) = self
468-
.try_respond_set_statements(client, &query_lower)
469-
.await?
470-
{
471-
return Ok(vec![resp]);
472-
}
471+
if let Some(resp) = self
472+
.try_respond_set_statements(client, &query_lower)
473+
.await?
474+
{
475+
return Ok(vec![resp]);
476+
}
473477

474-
if let Some(resp) = self
475-
.try_respond_show_statements(client, &query_lower)
476-
.await?
477-
{
478-
return Ok(vec![resp]);
479-
}
478+
if let Some(resp) = self
479+
.try_respond_show_statements(client, &query_lower)
480+
.await?
481+
{
482+
return Ok(vec![resp]);
483+
}
480484

481-
// Check if we're in a failed transaction and block non-transaction
482-
// commands
483-
if client.transaction_status() == TransactionStatus::Error {
484-
return Err(PgWireError::UserError(Box::new(
485+
// Check if we're in a failed transaction and block non-transaction
486+
// commands
487+
if client.transaction_status() == TransactionStatus::Error {
488+
return Err(PgWireError::UserError(Box::new(
485489
pgwire::error::ErrorInfo::new(
486490
"ERROR".to_string(),
487491
"25P01".to_string(),
488492
"current transaction is aborted, commands ignored until end of transaction block".to_string(),
489493
),
490494
)));
491-
}
492-
493-
let df_result = {
494-
let timeout = Self::get_statement_timeout(client);
495-
if let Some(timeout_duration) = timeout {
496-
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
497-
.await
498-
.map_err(|_| {
499-
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
500-
"ERROR".to_string(),
501-
"57014".to_string(), // query_canceled error code
502-
"canceling statement due to statement timeout".to_string(),
503-
)))
504-
})?
505-
} else {
506-
self.session_context.sql(&query).await
507495
}
508-
};
509496

510-
// Handle query execution errors and transaction state
511-
let df = match df_result {
512-
Ok(df) => df,
513-
Err(e) => {
514-
return Err(PgWireError::ApiError(Box::new(e)));
515-
}
516-
};
497+
let df_result = {
498+
let timeout = Self::get_statement_timeout(client);
499+
if let Some(timeout_duration) = timeout {
500+
tokio::time::timeout(timeout_duration, self.session_context.sql(&query))
501+
.await
502+
.map_err(|_| {
503+
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
504+
"ERROR".to_string(),
505+
"57014".to_string(), // query_canceled error code
506+
"canceling statement due to statement timeout".to_string(),
507+
)))
508+
})?
509+
} else {
510+
self.session_context.sql(&query).await
511+
}
512+
};
517513

518-
if query_lower.starts_with("insert into") {
519-
let resp = map_rows_affected_for_insert(&df).await?;
520-
Ok(vec![resp])
521-
} else {
522-
// For non-INSERT queries, return a regular Query response
523-
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
524-
Ok(vec![Response::Query(resp)])
514+
// Handle query execution errors and transaction state
515+
let df = match df_result {
516+
Ok(df) => df,
517+
Err(e) => {
518+
return Err(PgWireError::ApiError(Box::new(e)));
519+
}
520+
};
521+
522+
if query_lower.starts_with("insert into") {
523+
let resp = map_rows_affected_for_insert(&df).await?;
524+
results.push(resp);
525+
} else {
526+
// For non-INSERT queries, return a regular Query response
527+
let resp = df::encode_dataframe(df, &Format::UnifiedText).await?;
528+
results.push(Response::Query(resp));
529+
}
525530
}
531+
Ok(results)
526532
}
527533
}
528534

529535
#[async_trait]
530536
impl ExtendedQueryHandler for DfSessionService {
531-
type Statement = (String, LogicalPlan);
537+
type Statement = (String, Option<LogicalPlan>);
532538
type QueryParser = Parser;
533539

534540
fn query_parser(&self) -> Arc<Self::QueryParser> {
@@ -543,25 +549,28 @@ impl ExtendedQueryHandler for DfSessionService {
543549
where
544550
C: ClientInfo + Unpin + Send + Sync,
545551
{
546-
let (_, plan) = &target.statement;
547-
let schema = plan.schema();
548-
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
549-
let params = plan
550-
.get_parameter_types()
551-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
552-
553-
let mut param_types = Vec::with_capacity(params.len());
554-
for param_type in ordered_param_types(&params).iter() {
555-
// Fixed: Use &params
556-
if let Some(datatype) = param_type {
557-
let pgtype = into_pg_type(datatype)?;
558-
param_types.push(pgtype);
559-
} else {
560-
param_types.push(Type::UNKNOWN);
552+
if let (_, Some(plan)) = &target.statement {
553+
let schema = plan.schema();
554+
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), &Format::UnifiedBinary)?;
555+
let params = plan
556+
.get_parameter_types()
557+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
558+
559+
let mut param_types = Vec::with_capacity(params.len());
560+
for param_type in ordered_param_types(&params).iter() {
561+
// Fixed: Use &params
562+
if let Some(datatype) = param_type {
563+
let pgtype = into_pg_type(datatype)?;
564+
param_types.push(pgtype);
565+
} else {
566+
param_types.push(Type::UNKNOWN);
567+
}
561568
}
562-
}
563569

564-
Ok(DescribeStatementResponse::new(param_types, fields))
570+
Ok(DescribeStatementResponse::new(param_types, fields))
571+
} else {
572+
Ok(DescribeStatementResponse::no_data())
573+
}
565574
}
566575

567576
async fn do_describe_portal<C>(
@@ -572,12 +581,15 @@ impl ExtendedQueryHandler for DfSessionService {
572581
where
573582
C: ClientInfo + Unpin + Send + Sync,
574583
{
575-
let (_, plan) = &target.statement.statement;
576-
let format = &target.result_column_format;
577-
let schema = plan.schema();
578-
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
584+
if let (_, Some(plan)) = &target.statement.statement {
585+
let format = &target.result_column_format;
586+
let schema = plan.schema();
587+
let fields = arrow_schema_to_pg_fields(schema.as_arrow(), format)?;
579588

580-
Ok(DescribePortalResponse::new(fields))
589+
Ok(DescribePortalResponse::new(fields))
590+
} else {
591+
Ok(DescribePortalResponse::no_data())
592+
}
581593
}
582594

583595
async fn do_query<C>(
@@ -631,57 +643,60 @@ impl ExtendedQueryHandler for DfSessionService {
631643
)));
632644
}
633645

634-
let (_, plan) = &portal.statement.statement;
635-
636-
let param_types = plan
637-
.get_parameter_types()
638-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
639-
640-
let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
641-
642-
let plan = plan
643-
.clone()
644-
.replace_params_with_values(&param_values)
645-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
646-
// &param_values
647-
let optimised = self
648-
.session_context
649-
.state()
650-
.optimize(&plan)
651-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
652-
653-
let dataframe = {
654-
let timeout = Self::get_statement_timeout(client);
655-
if let Some(timeout_duration) = timeout {
656-
tokio::time::timeout(
657-
timeout_duration,
658-
self.session_context.execute_logical_plan(optimised),
659-
)
660-
.await
661-
.map_err(|_| {
662-
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
663-
"ERROR".to_string(),
664-
"57014".to_string(), // query_canceled error code
665-
"canceling statement due to statement timeout".to_string(),
666-
)))
667-
})?
668-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
669-
} else {
670-
self.session_context
671-
.execute_logical_plan(optimised)
646+
if let (_, Some(plan)) = &portal.statement.statement {
647+
let param_types = plan
648+
.get_parameter_types()
649+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
650+
651+
let param_values =
652+
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
653+
654+
let plan = plan
655+
.clone()
656+
.replace_params_with_values(&param_values)
657+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
658+
// &param_values
659+
let optimised = self
660+
.session_context
661+
.state()
662+
.optimize(&plan)
663+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
664+
665+
let dataframe = {
666+
let timeout = Self::get_statement_timeout(client);
667+
if let Some(timeout_duration) = timeout {
668+
tokio::time::timeout(
669+
timeout_duration,
670+
self.session_context.execute_logical_plan(optimised),
671+
)
672672
.await
673+
.map_err(|_| {
674+
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
675+
"ERROR".to_string(),
676+
"57014".to_string(), // query_canceled error code
677+
"canceling statement due to statement timeout".to_string(),
678+
)))
679+
})?
673680
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
674-
}
675-
};
681+
} else {
682+
self.session_context
683+
.execute_logical_plan(optimised)
684+
.await
685+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
686+
}
687+
};
676688

677-
if query.starts_with("insert into") {
678-
let resp = map_rows_affected_for_insert(&dataframe).await?;
689+
if query.starts_with("insert into") {
690+
let resp = map_rows_affected_for_insert(&dataframe).await?;
679691

680-
Ok(resp)
692+
Ok(resp)
693+
} else {
694+
// For non-INSERT queries, return a regular Query response
695+
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
696+
Ok(Response::Query(resp))
697+
}
681698
} else {
682-
// For non-INSERT queries, return a regular Query response
683-
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
684-
Ok(Response::Query(resp))
699+
Ok(Response::EmptyQuery)
685700
}
686701
}
687702
}
@@ -767,7 +782,7 @@ impl Parser {
767782

768783
#[async_trait]
769784
impl QueryParser for Parser {
770-
type Statement = (String, LogicalPlan);
785+
type Statement = (String, Option<LogicalPlan>);
771786

772787
async fn parse_sql<C>(
773788
&self,
@@ -782,13 +797,17 @@ impl QueryParser for Parser {
782797
.try_shortcut_parse_plan(sql)
783798
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
784799
{
785-
return Ok((sql.to_string(), plan));
800+
return Ok((sql.to_string(), Some(plan)));
786801
}
787802

788803
let mut statements = self
789804
.sql_parser
790805
.parse(sql)
791806
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
807+
if statements.is_empty() {
808+
return Ok((sql.to_string(), None));
809+
}
810+
792811
let statement = statements.remove(0);
793812

794813
let query = statement.to_string();
@@ -799,7 +818,7 @@ impl QueryParser for Parser {
799818
.statement_to_plan(Statement::Statement(Box::new(statement)))
800819
.await
801820
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
802-
Ok((query, logical_plan))
821+
Ok((query, Some(logical_plan)))
803822
}
804823
}
805824

0 commit comments

Comments
 (0)