@@ -14,8 +14,8 @@ use pgwire::api::auth::StartupHandler;
1414use pgwire:: api:: portal:: { Format , Portal } ;
1515use pgwire:: api:: query:: { ExtendedQueryHandler , SimpleQueryHandler } ;
1616use pgwire:: api:: results:: {
17- DescribePortalResponse , DescribeStatementResponse , FieldFormat , FieldInfo , QueryResponse ,
18- Response , Tag ,
17+ DescribePortalResponse , DescribeResponse , DescribeStatementResponse , FieldFormat , FieldInfo ,
18+ QueryResponse , Response , Tag ,
1919} ;
2020use pgwire:: api:: stmt:: QueryParser ;
2121use pgwire:: api:: stmt:: StoredStatement ;
@@ -438,97 +438,103 @@ impl SimpleQueryHandler for DfSessionService {
438438 return Ok ( vec ! [ resp] ) ;
439439 }
440440
441- let mut statements = self
441+ let statements = self
442442 . parser
443443 . sql_parser
444444 . parse ( query)
445445 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
446446
447- // TODO: deal with multiple statements
448- let statement = statements. remove ( 0 ) ;
449-
450- // TODO: improve statement check by using statement directly
451- let query = statement. to_string ( ) ;
452- let query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ;
453-
454- // Check permissions for the query (skip for SET, transaction, and SHOW statements)
455- if !query_lower. starts_with ( "set" )
456- && !query_lower. starts_with ( "begin" )
457- && !query_lower. starts_with ( "commit" )
458- && !query_lower. starts_with ( "rollback" )
459- && !query_lower. starts_with ( "start" )
460- && !query_lower. starts_with ( "end" )
461- && !query_lower. starts_with ( "abort" )
462- && !query_lower. starts_with ( "show" )
463- {
464- self . check_query_permission ( client, & query) . await ?;
465- }
447+ // empty query
448+ if statements. is_empty ( ) {
449+ return Ok ( vec ! [ Response :: EmptyQuery ] ) ;
450+ }
451+
452+ let mut results = vec ! [ ] ;
453+ for statement in statements {
454+ // TODO: improve statement check by using statement directly
455+ let query = statement. to_string ( ) ;
456+ let query_lower = query. to_lowercase ( ) . trim ( ) . to_string ( ) ;
457+
458+ // Check permissions for the query (skip for SET, transaction, and SHOW statements)
459+ if !query_lower. starts_with ( "set" )
460+ && !query_lower. starts_with ( "begin" )
461+ && !query_lower. starts_with ( "commit" )
462+ && !query_lower. starts_with ( "rollback" )
463+ && !query_lower. starts_with ( "start" )
464+ && !query_lower. starts_with ( "end" )
465+ && !query_lower. starts_with ( "abort" )
466+ && !query_lower. starts_with ( "show" )
467+ {
468+ self . check_query_permission ( client, & query) . await ?;
469+ }
466470
467- if let Some ( resp) = self
468- . try_respond_set_statements ( client, & query_lower)
469- . await ?
470- {
471- return Ok ( vec ! [ resp] ) ;
472- }
471+ if let Some ( resp) = self
472+ . try_respond_set_statements ( client, & query_lower)
473+ . await ?
474+ {
475+ return Ok ( vec ! [ resp] ) ;
476+ }
473477
474- if let Some ( resp) = self
475- . try_respond_show_statements ( client, & query_lower)
476- . await ?
477- {
478- return Ok ( vec ! [ resp] ) ;
479- }
478+ if let Some ( resp) = self
479+ . try_respond_show_statements ( client, & query_lower)
480+ . await ?
481+ {
482+ return Ok ( vec ! [ resp] ) ;
483+ }
480484
481- // Check if we're in a failed transaction and block non-transaction
482- // commands
483- if client. transaction_status ( ) == TransactionStatus :: Error {
484- return Err ( PgWireError :: UserError ( Box :: new (
485+ // Check if we're in a failed transaction and block non-transaction
486+ // commands
487+ if client. transaction_status ( ) == TransactionStatus :: Error {
488+ return Err ( PgWireError :: UserError ( Box :: new (
485489 pgwire:: error:: ErrorInfo :: new (
486490 "ERROR" . to_string ( ) ,
487491 "25P01" . to_string ( ) ,
488492 "current transaction is aborted, commands ignored until end of transaction block" . to_string ( ) ,
489493 ) ,
490494 ) ) ) ;
491- }
492-
493- let df_result = {
494- let timeout = Self :: get_statement_timeout ( client) ;
495- if let Some ( timeout_duration) = timeout {
496- tokio:: time:: timeout ( timeout_duration, self . session_context . sql ( & query) )
497- . await
498- . map_err ( |_| {
499- PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
500- "ERROR" . to_string ( ) ,
501- "57014" . to_string ( ) , // query_canceled error code
502- "canceling statement due to statement timeout" . to_string ( ) ,
503- ) ) )
504- } ) ?
505- } else {
506- self . session_context . sql ( & query) . await
507495 }
508- } ;
509496
510- // Handle query execution errors and transaction state
511- let df = match df_result {
512- Ok ( df) => df,
513- Err ( e) => {
514- return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
515- }
516- } ;
497+ let df_result = {
498+ let timeout = Self :: get_statement_timeout ( client) ;
499+ if let Some ( timeout_duration) = timeout {
500+ tokio:: time:: timeout ( timeout_duration, self . session_context . sql ( & query) )
501+ . await
502+ . map_err ( |_| {
503+ PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
504+ "ERROR" . to_string ( ) ,
505+ "57014" . to_string ( ) , // query_canceled error code
506+ "canceling statement due to statement timeout" . to_string ( ) ,
507+ ) ) )
508+ } ) ?
509+ } else {
510+ self . session_context . sql ( & query) . await
511+ }
512+ } ;
517513
518- if query_lower. starts_with ( "insert into" ) {
519- let resp = map_rows_affected_for_insert ( & df) . await ?;
520- Ok ( vec ! [ resp] )
521- } else {
522- // For non-INSERT queries, return a regular Query response
523- let resp = df:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
524- Ok ( vec ! [ Response :: Query ( resp) ] )
514+ // Handle query execution errors and transaction state
515+ let df = match df_result {
516+ Ok ( df) => df,
517+ Err ( e) => {
518+ return Err ( PgWireError :: ApiError ( Box :: new ( e) ) ) ;
519+ }
520+ } ;
521+
522+ if query_lower. starts_with ( "insert into" ) {
523+ let resp = map_rows_affected_for_insert ( & df) . await ?;
524+ results. push ( resp) ;
525+ } else {
526+ // For non-INSERT queries, return a regular Query response
527+ let resp = df:: encode_dataframe ( df, & Format :: UnifiedText ) . await ?;
528+ results. push ( Response :: Query ( resp) ) ;
529+ }
525530 }
531+ Ok ( results)
526532 }
527533}
528534
529535#[ async_trait]
530536impl ExtendedQueryHandler for DfSessionService {
531- type Statement = ( String , LogicalPlan ) ;
537+ type Statement = ( String , Option < LogicalPlan > ) ;
532538 type QueryParser = Parser ;
533539
534540 fn query_parser ( & self ) -> Arc < Self :: QueryParser > {
@@ -543,25 +549,28 @@ impl ExtendedQueryHandler for DfSessionService {
543549 where
544550 C : ClientInfo + Unpin + Send + Sync ,
545551 {
546- let ( _, plan) = & target. statement ;
547- let schema = plan. schema ( ) ;
548- let fields = arrow_schema_to_pg_fields ( schema. as_arrow ( ) , & Format :: UnifiedBinary ) ?;
549- let params = plan
550- . get_parameter_types ( )
551- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
552-
553- let mut param_types = Vec :: with_capacity ( params. len ( ) ) ;
554- for param_type in ordered_param_types ( & params) . iter ( ) {
555- // Fixed: Use ¶ms
556- if let Some ( datatype) = param_type {
557- let pgtype = into_pg_type ( datatype) ?;
558- param_types. push ( pgtype) ;
559- } else {
560- param_types. push ( Type :: UNKNOWN ) ;
552+ if let ( _, Some ( plan) ) = & target. statement {
553+ let schema = plan. schema ( ) ;
554+ let fields = arrow_schema_to_pg_fields ( schema. as_arrow ( ) , & Format :: UnifiedBinary ) ?;
555+ let params = plan
556+ . get_parameter_types ( )
557+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
558+
559+ let mut param_types = Vec :: with_capacity ( params. len ( ) ) ;
560+ for param_type in ordered_param_types ( & params) . iter ( ) {
561+ // Fixed: Use ¶ms
562+ if let Some ( datatype) = param_type {
563+ let pgtype = into_pg_type ( datatype) ?;
564+ param_types. push ( pgtype) ;
565+ } else {
566+ param_types. push ( Type :: UNKNOWN ) ;
567+ }
561568 }
562- }
563569
564- Ok ( DescribeStatementResponse :: new ( param_types, fields) )
570+ Ok ( DescribeStatementResponse :: new ( param_types, fields) )
571+ } else {
572+ Ok ( DescribeStatementResponse :: no_data ( ) )
573+ }
565574 }
566575
567576 async fn do_describe_portal < C > (
@@ -572,12 +581,15 @@ impl ExtendedQueryHandler for DfSessionService {
572581 where
573582 C : ClientInfo + Unpin + Send + Sync ,
574583 {
575- let ( _, plan) = & target. statement . statement ;
576- let format = & target. result_column_format ;
577- let schema = plan. schema ( ) ;
578- let fields = arrow_schema_to_pg_fields ( schema. as_arrow ( ) , format) ?;
584+ if let ( _, Some ( plan) ) = & target. statement . statement {
585+ let format = & target. result_column_format ;
586+ let schema = plan. schema ( ) ;
587+ let fields = arrow_schema_to_pg_fields ( schema. as_arrow ( ) , format) ?;
579588
580- Ok ( DescribePortalResponse :: new ( fields) )
589+ Ok ( DescribePortalResponse :: new ( fields) )
590+ } else {
591+ Ok ( DescribePortalResponse :: no_data ( ) )
592+ }
581593 }
582594
583595 async fn do_query < C > (
@@ -631,57 +643,60 @@ impl ExtendedQueryHandler for DfSessionService {
631643 ) ) ) ;
632644 }
633645
634- let ( _, plan) = & portal. statement . statement ;
635-
636- let param_types = plan
637- . get_parameter_types ( )
638- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
639-
640- let param_values = df:: deserialize_parameters ( portal, & ordered_param_types ( & param_types) ) ?; // Fixed: Use ¶m_types
641-
642- let plan = plan
643- . clone ( )
644- . replace_params_with_values ( & param_values)
645- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?; // Fixed: Use
646- // ¶m_values
647- let optimised = self
648- . session_context
649- . state ( )
650- . optimize ( & plan)
651- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
652-
653- let dataframe = {
654- let timeout = Self :: get_statement_timeout ( client) ;
655- if let Some ( timeout_duration) = timeout {
656- tokio:: time:: timeout (
657- timeout_duration,
658- self . session_context . execute_logical_plan ( optimised) ,
659- )
660- . await
661- . map_err ( |_| {
662- PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
663- "ERROR" . to_string ( ) ,
664- "57014" . to_string ( ) , // query_canceled error code
665- "canceling statement due to statement timeout" . to_string ( ) ,
666- ) ) )
667- } ) ?
668- . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
669- } else {
670- self . session_context
671- . execute_logical_plan ( optimised)
646+ if let ( _, Some ( plan) ) = & portal. statement . statement {
647+ let param_types = plan
648+ . get_parameter_types ( )
649+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
650+
651+ let param_values =
652+ df:: deserialize_parameters ( portal, & ordered_param_types ( & param_types) ) ?; // Fixed: Use ¶m_types
653+
654+ let plan = plan
655+ . clone ( )
656+ . replace_params_with_values ( & param_values)
657+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?; // Fixed: Use
658+ // ¶m_values
659+ let optimised = self
660+ . session_context
661+ . state ( )
662+ . optimize ( & plan)
663+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
664+
665+ let dataframe = {
666+ let timeout = Self :: get_statement_timeout ( client) ;
667+ if let Some ( timeout_duration) = timeout {
668+ tokio:: time:: timeout (
669+ timeout_duration,
670+ self . session_context . execute_logical_plan ( optimised) ,
671+ )
672672 . await
673+ . map_err ( |_| {
674+ PgWireError :: UserError ( Box :: new ( pgwire:: error:: ErrorInfo :: new (
675+ "ERROR" . to_string ( ) ,
676+ "57014" . to_string ( ) , // query_canceled error code
677+ "canceling statement due to statement timeout" . to_string ( ) ,
678+ ) ) )
679+ } ) ?
673680 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
674- }
675- } ;
681+ } else {
682+ self . session_context
683+ . execute_logical_plan ( optimised)
684+ . await
685+ . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
686+ }
687+ } ;
676688
677- if query. starts_with ( "insert into" ) {
678- let resp = map_rows_affected_for_insert ( & dataframe) . await ?;
689+ if query. starts_with ( "insert into" ) {
690+ let resp = map_rows_affected_for_insert ( & dataframe) . await ?;
679691
680- Ok ( resp)
692+ Ok ( resp)
693+ } else {
694+ // For non-INSERT queries, return a regular Query response
695+ let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
696+ Ok ( Response :: Query ( resp) )
697+ }
681698 } else {
682- // For non-INSERT queries, return a regular Query response
683- let resp = df:: encode_dataframe ( dataframe, & portal. result_column_format ) . await ?;
684- Ok ( Response :: Query ( resp) )
699+ Ok ( Response :: EmptyQuery )
685700 }
686701 }
687702}
@@ -767,7 +782,7 @@ impl Parser {
767782
768783#[ async_trait]
769784impl QueryParser for Parser {
770- type Statement = ( String , LogicalPlan ) ;
785+ type Statement = ( String , Option < LogicalPlan > ) ;
771786
772787 async fn parse_sql < C > (
773788 & self ,
@@ -782,13 +797,17 @@ impl QueryParser for Parser {
782797 . try_shortcut_parse_plan ( sql)
783798 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?
784799 {
785- return Ok ( ( sql. to_string ( ) , plan) ) ;
800+ return Ok ( ( sql. to_string ( ) , Some ( plan) ) ) ;
786801 }
787802
788803 let mut statements = self
789804 . sql_parser
790805 . parse ( sql)
791806 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
807+ if statements. is_empty ( ) {
808+ return Ok ( ( sql. to_string ( ) , None ) ) ;
809+ }
810+
792811 let statement = statements. remove ( 0 ) ;
793812
794813 let query = statement. to_string ( ) ;
@@ -799,7 +818,7 @@ impl QueryParser for Parser {
799818 . statement_to_plan ( Statement :: Statement ( Box :: new ( statement) ) )
800819 . await
801820 . map_err ( |e| PgWireError :: ApiError ( Box :: new ( e) ) ) ?;
802- Ok ( ( query, logical_plan) )
821+ Ok ( ( query, Some ( logical_plan) ) )
803822 }
804823}
805824
0 commit comments