diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index fcbded3..683afc1 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -516,27 +516,8 @@ impl SimpleQueryHandler for DfSessionService { }; if query_lower.starts_with("insert into") { - // For INSERT queries, we need to execute the query to get the row count - // and return an Execution response with the proper tag - let result = df - .clone() - .collect() - .await - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; - - // Extract count field from the first batch - let rows_affected = result - .first() - .and_then(|batch| batch.column_by_name("count")) - .and_then(|col| { - col.as_any() - .downcast_ref::() - }) - .map_or(0, |array| array.value(0) as usize); - - // Create INSERT tag with the affected row count - let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected); - Ok(vec![Response::Execution(tag)]) + let resp = map_rows_affected_for_insert(&df).await?; + Ok(vec![resp]) } else { // For non-INSERT queries, return a regular Query response let resp = df::encode_dataframe(df, &Format::UnifiedText).await?; @@ -692,11 +673,43 @@ impl ExtendedQueryHandler for DfSessionService { .map_err(|e| PgWireError::ApiError(Box::new(e)))? } }; - let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; - Ok(Response::Query(resp)) + + if query.starts_with("insert into") { + let resp = map_rows_affected_for_insert(&dataframe).await?; + + Ok(resp) + } else { + // For non-INSERT queries, return a regular Query response + let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?; + Ok(Response::Query(resp)) + } } } +async fn map_rows_affected_for_insert<'a>(df: &DataFrame) -> PgWireResult> { + // For INSERT queries, we need to execute the query to get the row count + // and return an Execution response with the proper tag + let result = df + .clone() + .collect() + .await + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + + // Extract count field from the first batch + let rows_affected = result + .first() + .and_then(|batch| batch.column_by_name("count")) + .and_then(|col| { + col.as_any() + .downcast_ref::() + }) + .map_or(0, |array| array.value(0) as usize); + + // Create INSERT tag with the affected row count + let tag = Tag::new("INSERT").with_oid(0).with_rows(rows_affected); + Ok(Response::Execution(tag)) +} + pub struct Parser { session_context: Arc, sql_parser: PostgresCompatibilityParser,