diff --git a/src/query/expression/src/utils/udf_client.rs b/src/query/expression/src/utils/udf_client.rs index 953314672b703..c1510509110cb 100644 --- a/src/query/expression/src/utils/udf_client.rs +++ b/src/query/expression/src/utils/udf_client.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::error::Error as StdError; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; @@ -20,8 +21,10 @@ use std::time::Instant; use arrow_array::RecordBatch; use arrow_flight::decode::FlightRecordBatchStream; use arrow_flight::encode::FlightDataEncoderBuilder; +use arrow_flight::error::FlightError; use arrow_flight::flight_service_client::FlightServiceClient; use arrow_flight::FlightDescriptor; +use arrow_schema::ArrowError; use arrow_select::concat::concat_batches; use databend_common_base::headers::HEADER_FUNCTION; use databend_common_base::headers::HEADER_FUNCTION_HANDLER; @@ -49,6 +52,7 @@ use tonic::transport::channel::Channel; use tonic::transport::ClientTlsConfig; use tonic::transport::Endpoint; use tonic::Request; +use tonic::Status; use crate::types::DataType; use crate::variant_transform::contains_variant; @@ -64,6 +68,26 @@ const UDF_KEEP_ALIVE_TIMEOUT_SEC: u64 = 20; // 4MB by default, we use 16G // max_encoding_message_size is usize::max by default const MAX_DECODING_MESSAGE_SIZE: usize = 16 * 1024 * 1024 * 1024; +// These lowercase fragments map brittle transport errors to friendlier messaging. +// Keep the list up to date as dependencies evolve or add new patterns when gaps appear. 
+const TRANSPORT_ERROR_SNIPPETS: &[&str] = &[
+    "h2 protocol error",
+    "broken pipe",
+    "connection reset",
+    "error reading a body from connection",
+    "connection refused",
+    "network is unreachable",
+    "no route to host",
+];
+
+/// Coarse category assigned to a `FlightError` so that
+/// `handle_flight_decode_error` can choose a user-facing message.
+#[derive(Debug)]
+enum FlightDecodeIssue {
+    /// The error text contained one of `TRANSPORT_ERROR_SNIPPETS`.
+    TransportInterrupted,
+    /// Error text that matched no transport snippet; the payload is that text.
+    ServerStatus(String),
+    /// The underlying error was `ArrowError::SchemaError`.
+    SchemaMismatch,
+    /// Flight protocol/decode failures and the Arrow variants listed in
+    /// `classify_arrow_error` as unparsable data.
+    MalformedData,
+    /// Everything else, e.g. `FlightError::NotYetImplemented`.
+    Other,
+}

 #[derive(Debug, Clone)]
 pub struct UDFFlightClient {
@@ -336,8 +360,7 @@ impl UDFFlightClient {

         if result_fields[0].data_type() != return_type {
             return Err(ErrorCode::UDFSchemaMismatch(format!(
-                "UDF server return incorrect type, expected: {}, but got: {}",
-                return_type,
+                "The user-defined function \"{func_name}\" returned an unexpected schema. Expected result type {return_type}, but got {}.",
                 result_fields[0].data_type()
             )));
         }
@@ -380,11 +403,7 @@ impl UDFFlightClient {
         let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
             flight_data_stream.map_err(|err| err.into()),
         )
-        .map_err(|err| {
-            ErrorCode::UDFDataError(format!(
-                "Decode record batch failed on UDF function {func_name}: {err}"
-            ))
-        });
+        .map_err(|err| handle_flight_decode_error(func_name, err));

         let batches: Vec = record_batch_stream.try_collect().await?;
         if batches.is_empty() {
@@ -399,6 +418,88 @@ impl UDFFlightClient {
     }
 }

+/// Converts a Flight stream error raised while calling UDF `func_name` into an
+/// `ErrorCode`.
+///
+/// Every category produces `ErrorCode::UDFDataError`; only the message differs.
+/// All messages except the `ServerStatus` one embed the original error text.
+fn handle_flight_decode_error(func_name: &str, err: FlightError) -> ErrorCode {
+    let issue = classify_flight_error(&err);
+    let err_text = err.to_string();
+
+    match issue {
+        FlightDecodeIssue::TransportInterrupted => ErrorCode::UDFDataError(format!(
+            "The user-defined function \"{func_name}\" stopped responding before it finished. Retry the query; if it keeps failing, ensure the UDF server is running or review its logs. (details: {err_text})"
+        )),
+        FlightDecodeIssue::ServerStatus(status) => ErrorCode::UDFDataError(format!(
+            "The user-defined function \"{func_name}\" reported an error: {status}. Review the UDF server logs."
+        )),
+        FlightDecodeIssue::SchemaMismatch => ErrorCode::UDFDataError(format!(
+            "The user-defined function \"{func_name}\" returned an unexpected schema. Ensure the UDF definition matches the server output. (details: {err_text})"
+        )),
+        FlightDecodeIssue::MalformedData => ErrorCode::UDFDataError(format!(
+            "The user-defined function \"{func_name}\" returned data that Databend could not parse. Check the UDF implementation or its logs. (details: {err_text})"
+        )),
+        FlightDecodeIssue::Other => ErrorCode::UDFDataError(format!(
+            "Decode record batch failed on UDF function \"{func_name}\": {err_text}"
+        )),
+    }
+}
+
+/// Maps each `FlightError` variant to a `FlightDecodeIssue`, delegating to the
+/// Arrow, status and external-error classifiers when the variant wraps another
+/// error.
+fn classify_flight_error(err: &FlightError) -> FlightDecodeIssue {
+    match err {
+        FlightError::Arrow(arrow_err) => classify_arrow_error(arrow_err),
+        FlightError::Tonic(status) => classify_status(status),
+        FlightError::ExternalError(source) => classify_external_error(source.as_ref()),
+        FlightError::ProtocolError(_) | FlightError::DecodeError(_) => {
+            FlightDecodeIssue::MalformedData
+        }
+        FlightError::NotYetImplemented(_) => FlightDecodeIssue::Other,
+    }
+}
+
+/// Classifies an `ArrowError`: schema errors become `SchemaMismatch`, wrapped
+/// external errors and I/O messages are classified by their content, the listed
+/// parse/compute/format variants become `MalformedData`, and any remaining
+/// variant becomes `Other`.
+fn classify_arrow_error(err: &ArrowError) -> FlightDecodeIssue {
+    match err {
+        ArrowError::SchemaError(_) => FlightDecodeIssue::SchemaMismatch,
+        ArrowError::ExternalError(source) => classify_external_error(source.as_ref()),
+        ArrowError::IoError(message, _) => classify_error_message(message),
+        ArrowError::ParseError(_)
+        | ArrowError::InvalidArgumentError(_)
+        | ArrowError::ComputeError(_)
+        | ArrowError::JsonError(_)
+        | ArrowError::CsvError(_)
+        | ArrowError::IpcError(_)
+        | ArrowError::CDataInterface(_)
+        | ArrowError::ParquetError(_) => FlightDecodeIssue::MalformedData,
+        _ => FlightDecodeIssue::Other,
+    }
+}
+
+/// Classifies a `Status` from its message text alone; the status code is not
+/// consulted.
+fn classify_status(status: &Status) -> FlightDecodeIssue {
+    classify_error_message(status.message())
+}
+
+/// Classifies a dynamically typed error by trying a series of downcasts and
+/// falling back to its display text.
+///
+/// NOTE(review): the target types of the `downcast_ref` calls are not spelled
+/// out here; the binding names suggest an Arrow error, a `Status` and an I/O
+/// error - confirm. The I/O branch and the final fallback both classify the
+/// `to_string()` output.
+fn classify_external_error(error: &(dyn StdError + Send + Sync + 'static)) -> FlightDecodeIssue {
+    if let Some(arrow_err) = error.downcast_ref::() {
+        classify_arrow_error(arrow_err)
+    } else if let Some(status) = error.downcast_ref::() {
+        classify_status(status)
+    } else if let Some(io_error) = error.downcast_ref::() {
+        classify_error_message(&io_error.to_string())
+    } else {
+        classify_error_message(&error.to_string())
+    }
+}
+
+/// Returns `TransportInterrupted` when `message` matches a transport snippet;
+/// otherwise keeps the text as `ServerStatus`.
+fn classify_error_message(message: &str) -> FlightDecodeIssue {
+    if is_transport_error_message(message) {
+        FlightDecodeIssue::TransportInterrupted
+    } else {
+        FlightDecodeIssue::ServerStatus(message.to_string())
+    }
+}
+
+/// Returns `true` when `message`, lowercased with ASCII rules, contains any
+/// entry of `TRANSPORT_ERROR_SNIPPETS`.
+pub fn is_transport_error_message(message: &str) -> bool {
+    let lower = message.to_ascii_lowercase();
+    TRANSPORT_ERROR_SNIPPETS
+        .iter()
+        .any(|snippet| lower.contains(snippet))
+}
 pub fn error_kind(message: &str) -> &str {
     let message = message.to_ascii_lowercase();
     if message.contains("timeout") || message.contains("timedout") {
@@ -418,3 +519,70 @@ pub fn error_kind(message: &str) -> &str {
         "Other"
     }
 }
+
+// Unit tests pinning the message fragment that `handle_flight_decode_error`
+// produces for each error category.
+#[cfg(test)]
+mod tests {
+    use tonic::Code;
+
+    use super::*;
+
+    // A status whose message contains a transport snippet yields the
+    // "stopped responding" hint.
+    #[test]
+    fn transport_error_returns_interrupt_hint() {
+        let err = handle_flight_decode_error(
+            "test_udf",
+            FlightError::Tonic(Box::new(Status::new(
+                Code::Internal,
+                "h2 protocol error: error reading a body from connection",
+            ))),
+        );
+        let message = err.message();
+        assert!(
+            message.contains("stopped responding before it finished"),
+            "unexpected transport hint: {message}"
+        );
+    }
+
+    // A status message without any transport snippet is passed through verbatim.
+    #[test]
+    fn server_status_is_preserved() {
+        let err = handle_flight_decode_error(
+            "test_udf",
+            FlightError::Tonic(Box::new(Status::new(
+                Code::Internal,
+                "remote handler returned validation error",
+            ))),
+        );
+        let message = err.message();
+        assert!(
+            message.contains("reported an error: remote handler returned validation error"),
+            "unexpected server status message: {message}"
+        );
+    }
+
+    // `ArrowError::SchemaError` maps to the schema-mismatch hint.
+    #[test]
+    fn schema_mismatch_detected() {
+        let err = handle_flight_decode_error(
+            "test_udf",
+            FlightError::Arrow(ArrowError::SchemaError(
+                "expected Int32, got Utf8".to_string(),
+            )),
+        );
+        let message = err.message();
+        assert!(
+            message.contains("returned an unexpected schema"),
+            "schema mismatch hint missing: {message}"
+        );
+    }
+
+    // `ArrowError::ParseError` maps to the "could not parse" hint.
+    #[test]
+    fn malformed_data_reported() {
+        let err = handle_flight_decode_error(
+            "test_udf",
+            FlightError::Arrow(ArrowError::ParseError("bad payload".to_string())),
+        );
+        let message = err.message();
+        assert!(
+            message.contains("could not parse"),
+            "malformed data hint missing: {message}"
+        );
+    }
+}
diff --git a/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs
index 8829cafdb2dd8..a1e7b4344443a 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs
@@ -24,6 +24,7 @@ use databend_common_catalog::table_context::TableContext;
 use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
 use databend_common_expression::udf_client::error_kind;
+use databend_common_expression::udf_client::is_transport_error_message;
 use databend_common_expression::udf_client::UDFFlightClient;
 use databend_common_expression::BlockEntry;
 use databend_common_expression::ColumnBuilder;
@@ -156,7 +157,7 @@ fn retry_on(err: &databend_common_exception::ErrorCode) -> bool {
     if err.code() == ErrorCode::U_D_F_DATA_ERROR {
         let message = err.message();
         // this means the server can't handle the request in 60s
-        if message.contains("h2 protocol error") {
+        if is_transport_error_message(&message) {
             return false;
         }
     }
diff --git a/src/query/service/tests/it/pipelines/mod.rs b/src/query/service/tests/it/pipelines/mod.rs
index 9048ca928cbc9..ea80379fa4c10 100644
--- a/src/query/service/tests/it/pipelines/mod.rs
+++ b/src/query/service/tests/it/pipelines/mod.rs
@@ -15,3 +15,4 @@
 mod executor;
 mod filter;
 mod transforms;
+mod udf_transport;
diff --git a/src/query/service/tests/it/pipelines/udf_transport.rs b/src/query/service/tests/it/pipelines/udf_transport.rs
new file mode 100644
index 0000000000000..3868ab4fc6684
--- /dev/null
+++ b/src/query/service/tests/it/pipelines/udf_transport.rs
@@ -0,0 +1,294 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::pin::Pin;
+use std::sync::Arc;
+
+use arrow_array::RecordBatch;
+use arrow_array::StringArray;
+use arrow_flight::encode::FlightDataEncoderBuilder;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::flight_service_server::FlightServiceServer;
+use arrow_flight::Action;
+use arrow_flight::ActionType;
+use arrow_flight::Criteria;
+use arrow_flight::Empty;
+use arrow_flight::FlightData;
+use arrow_flight::FlightDescriptor;
+use arrow_flight::FlightEndpoint;
+use arrow_flight::FlightInfo;
+use arrow_flight::HandshakeRequest;
+use arrow_flight::HandshakeResponse;
+use arrow_flight::PollInfo;
+use arrow_flight::PutResult;
+use arrow_flight::SchemaResult;
+use arrow_flight::Ticket;
+use arrow_schema::DataType as ArrowDataType;
+use arrow_schema::Field;
+use arrow_schema::Schema;
+use databend_common_base::base::tokio;
+use databend_common_exception::Result;
+use databend_common_expression::types::DataType;
+use databend_common_expression::udf_client::UDFFlightClient;
+use databend_common_expression::BlockEntry;
+use databend_common_expression::Column;
+use futures::stream;
+use futures::Stream;
+use futures::StreamExt;
+use tokio::time::timeout;
+use tokio_stream::wrappers::TcpListenerStream;
+use tonic::transport::Server;
+use tonic::Request;
+use tonic::Response;
+use tonic::Status;
+
+/// Selects which failure the mock Flight server produces from `do_exchange`.
+#[derive(Clone)]
+enum MockMode {
+    /// Stream yields an internal `Status` whose message starts with
+    /// "h2 protocol error".
+    Transport,
+    /// Stream yields an invalid-argument `Status` carrying the given message.
+    ServerStatus(String),
+    /// Stream yields a single default-constructed `FlightData`.
+    MalformedData,
+    /// Stream yields an encoded batch with one Utf8 column named "result";
+    /// `get_flight_info` also answers with a schema in this mode.
+    SchemaMismatch,
+}
+
+/// Minimal Flight service whose behaviour is driven entirely by `mode`.
+struct MockFlightService {
+    mode: MockMode,
+}
+
+// Boxed, pinned stream alias used for every streaming RPC below.
+// NOTE(review): the generic parameters of this alias are not spelled out here - confirm.
+type FlightStream = Pin> + Send + 'static>>;
+
+// Only `get_flight_info`, `get_schema` and `do_exchange` depend on `mode`;
+// every other RPC returns `Status::unimplemented`.
+#[tonic::async_trait]
+impl FlightService for MockFlightService {
+    type HandshakeStream = FlightStream;
+
+    async fn handshake(
+        &self,
+        _request: Request>,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("handshake not supported"))
+    }
+
+    type ListFlightsStream = FlightStream;
+
+    async fn list_flights(
+        &self,
+        _request: Request,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("list flights not supported"))
+    }
+
+    // In `SchemaMismatch` mode, replies with a two-field schema
+    // (`arg1: Null`, `result: Utf8`) and echoes the request descriptor;
+    // otherwise unimplemented.
+    async fn get_flight_info(
+        &self,
+        request: Request,
+    ) -> std::result::Result, Status> {
+        if matches!(self.mode, MockMode::SchemaMismatch) {
+            let schema = Schema::new(vec![
+                Field::new("arg1", ArrowDataType::Null, true),
+                Field::new("result", ArrowDataType::Utf8, false),
+            ]);
+            let descriptor = request.into_inner();
+            let info = FlightInfo::new()
+                .try_with_schema(&schema)
+                .map_err(|e| Status::internal(e.to_string()))?
+                .with_descriptor(descriptor)
+                .with_endpoint(FlightEndpoint::new());
+            Ok(Response::new(info))
+        } else {
+            Err(Status::unimplemented("get flight info not supported"))
+        }
+    }
+
+    async fn poll_flight_info(
+        &self,
+        _request: Request,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("poll flight info not supported"))
+    }
+
+    // Always fails: with an internal "schema mismatch" status in
+    // `SchemaMismatch` mode, with unimplemented otherwise.
+    async fn get_schema(
+        &self,
+        _request: Request,
+    ) -> std::result::Result, Status> {
+        match &self.mode {
+            MockMode::SchemaMismatch => {
+                let err = Status::internal("schema mismatch: expected Int32, got Utf8");
+                Err(err)
+            }
+            _ => Err(Status::unimplemented("get schema not supported")),
+        }
+    }
+
+    type DoGetStream = FlightStream;
+
+    async fn do_get(
+        &self,
+        _request: Request,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("do_get not supported"))
+    }
+
+    type DoPutStream = FlightStream;
+
+    async fn do_put(
+        &self,
+        _request: Request>,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("do_put not supported"))
+    }
+
+    type DoExchangeStream = FlightStream;
+
+    // Ignores the request and returns a response stream chosen by `mode`;
+    // the call itself always succeeds, failures are delivered through the stream.
+    async fn do_exchange(
+        &self,
+        _request: Request>,
+    ) -> std::result::Result, Status> {
+        let stream: FlightStream = match &self.mode {
+            MockMode::Transport => Box::pin(stream::once(async {
+                Err(Status::internal(
+                    "h2 protocol error: error reading a body from connection",
+                ))
+            })),
+            MockMode::ServerStatus(message) => {
+                // Clone so the async block can own the text.
+                let message = message.clone();
+                Box::pin(stream::once(async move {
+                    Err(Status::invalid_argument(message))
+                }))
+            }
+            MockMode::MalformedData => {
+                let invalid = FlightData::default();
+                Box::pin(stream::iter(vec![Ok(invalid)]))
+            }
+            MockMode::SchemaMismatch => {
+                // One-row Utf8 batch; `run_mock_exchange` asks for `DataType::Null`.
+                let schema = Arc::new(Schema::new(vec![Field::new(
+                    "result",
+                    ArrowDataType::Utf8,
+                    false,
+                )]));
+                let batch =
+                    RecordBatch::try_new(schema.clone(), vec![Arc::new(StringArray::from(vec![
+                        "hello",
+                    ]))])
+                    .expect("build record batch");
+                let stream = stream::iter(vec![Ok(batch)]);
+                let encoder = FlightDataEncoderBuilder::new()
+                    .with_schema(schema)
+                    .build(stream)
+                    .map(|res| res.map_err(|e| Status::internal(e.to_string())));
+                Box::pin(encoder)
+            }
+        };
+        Ok(Response::new(stream))
+    }
+
+    type DoActionStream = FlightStream;
+
+    async fn do_action(
+        &self,
+        _request: Request,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("do_action not supported"))
+    }
+
+    type ListActionsStream = FlightStream;
+
+    async fn list_actions(
+        &self,
+        _request: Request,
+    ) -> std::result::Result, Status> {
+        Err(Status::unimplemented("list_actions not supported"))
+    }
+}
+
+// Each test below runs one mock mode end to end and checks that the client's
+// error message contains the expected fragment.
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn transport_error_returns_friendly_hint() -> Result<()> {
+    let message = run_mock_exchange(MockMode::Transport).await?;
+    assert!(
+        message.contains("stopped responding before it finished"),
+        "unexpected transport message: {message}"
+    );
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn server_status_error_includes_message() -> Result<()> {
+    let message = run_mock_exchange(MockMode::ServerStatus(
+        "remote validation failed".to_string(),
+    ))
+    .await?;
+    assert!(
+        message.contains("reported an error: remote validation failed"),
+        "unexpected server status message: {message}"
+    );
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn malformed_data_returns_parse_hint() -> Result<()> {
+    let message = run_mock_exchange(MockMode::MalformedData).await?;
+    assert!(
+        message.contains("could not parse"),
+        "unexpected malformed data message: {message}"
+    );
+    Ok(())
+}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn schema_mismatch_returns_schema_hint() -> Result<()> {
+    let message = run_mock_exchange(MockMode::SchemaMismatch).await?;
+    assert!(
+        message.contains("returned an unexpected schema"),
+        "unexpected schema mismatch message: {message}"
+    );
+    Ok(())
+}
+
+/// Starts a `MockFlightService` in `mode` on an ephemeral local port, performs
+/// one `do_exchange` call through `UDFFlightClient`, shuts the server down and
+/// returns the message of the resulting error.
+///
+/// Panics if the exchange or the shutdown takes longer than five seconds, or
+/// if the exchange unexpectedly succeeds.
+async fn run_mock_exchange(mode: MockMode) -> Result {
+    // Port 0 lets the OS choose a free port; the real address is read back below.
+    let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
+    let address = listener.local_addr()?;
+    let incoming = TcpListenerStream::new(listener);
+
+    let service = MockFlightService { mode };
+    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
+    let server = databend_common_base::runtime::spawn(async move {
+        let _ = Server::builder()
+            .add_service(FlightServiceServer::new(service))
+            .serve_with_incoming_shutdown(incoming, async move {
+                let _ = shutdown_rx.await;
+            })
+            .await;
+    });
+
+    // NOTE(review): presumably gives the spawned server task time to start;
+    // the listener is already bound at this point, so confirm it is needed.
+    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
+
+    let endpoint =
+        UDFFlightClient::build_endpoint(&format!("http://{}", address), 3, 3, "udf-client-test")?;
+    let mut client = UDFFlightClient::connect("mock_udf", endpoint, 3, 1024).await?;
+
+    // Single Null argument column and a Null return type.
+    let num_rows = 1;
+    let args = vec![BlockEntry::from(Column::Null { len: num_rows })];
+    let return_type = DataType::Null;
+    let result = timeout(
+        std::time::Duration::from_secs(5),
+        client.do_exchange("mock_udf", "mock_handler", num_rows, args, &return_type),
+    )
+    .await
+    .expect("do_exchange future timed out");
+
+    // Signal shutdown and wait (bounded) for the server task to finish.
+    let _ = shutdown_tx.send(());
+    let _ = timeout(std::time::Duration::from_secs(5), server)
+        .await
+        .expect("server shutdown timed out");
+
+    let err = result.expect_err("expected failure");
+    Ok(err.message().to_string())
+}