Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 175 additions & 7 deletions src/query/expression/src/utils/udf_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::error::Error as StdError;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
Expand All @@ -20,8 +21,10 @@ use std::time::Instant;
use arrow_array::RecordBatch;
use arrow_flight::decode::FlightRecordBatchStream;
use arrow_flight::encode::FlightDataEncoderBuilder;
use arrow_flight::error::FlightError;
use arrow_flight::flight_service_client::FlightServiceClient;
use arrow_flight::FlightDescriptor;
use arrow_schema::ArrowError;
use arrow_select::concat::concat_batches;
use databend_common_base::headers::HEADER_FUNCTION;
use databend_common_base::headers::HEADER_FUNCTION_HANDLER;
Expand Down Expand Up @@ -49,6 +52,7 @@ use tonic::transport::channel::Channel;
use tonic::transport::ClientTlsConfig;
use tonic::transport::Endpoint;
use tonic::Request;
use tonic::Status;

use crate::types::DataType;
use crate::variant_transform::contains_variant;
Expand All @@ -64,6 +68,26 @@ const UDF_KEEP_ALIVE_TIMEOUT_SEC: u64 = 20;
// Maximum size of a single decoded message: 16 GiB.
// The decoding limit is 4MB by default, which is why a larger value is defined here.
// No matching encoding constant is needed: max_encoding_message_size is
// usize::MAX by default.
const MAX_DECODING_MESSAGE_SIZE: usize = 16 * 1024 * 1024 * 1024;
// Substrings that mark an error message as a transport-level failure, so the
// brittle low-level text can be mapped to friendlier messaging.
// `is_transport_error_message` lowercases the message before searching it, so
// every entry in this list must itself be lowercase or it will never match.
// Keep the list up to date as dependencies evolve, and add new patterns when gaps appear.
const TRANSPORT_ERROR_SNIPPETS: &[&str] = &[
"h2 protocol error",
"broken pipe",
"connection reset",
"error reading a body from connection",
"connection refused",
"network is unreachable",
"no route to host",
];

/// Coarse classification of a failure seen while decoding a UDF Flight stream.
///
/// `handle_flight_decode_error` selects the user-facing message from this value.
/// `Clone`, `PartialEq` and `Eq` are derived so that a classification can be
/// compared directly (for example in tests) instead of by message substring.
#[derive(Debug, Clone, PartialEq, Eq)]
enum FlightDecodeIssue {
    /// The error text matched one of `TRANSPORT_ERROR_SNIPPETS`.
    TransportInterrupted,
    /// A non-transport error; carries the original message verbatim.
    ServerStatus(String),
    /// Arrow reported a schema error.
    SchemaMismatch,
    /// The payload could not be decoded or parsed.
    MalformedData,
    /// Anything that fits none of the categories above.
    Other,
}

#[derive(Debug, Clone)]
pub struct UDFFlightClient {
Expand Down Expand Up @@ -336,8 +360,7 @@ impl UDFFlightClient {

if result_fields[0].data_type() != return_type {
return Err(ErrorCode::UDFSchemaMismatch(format!(
"UDF server return incorrect type, expected: {}, but got: {}",
return_type,
"The user-defined function \"{func_name}\" returned an unexpected schema. Expected result type {return_type}, but got {}.",
result_fields[0].data_type()
)));
}
Expand Down Expand Up @@ -380,11 +403,7 @@ impl UDFFlightClient {
let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
flight_data_stream.map_err(|err| err.into()),
)
.map_err(|err| {
ErrorCode::UDFDataError(format!(
"Decode record batch failed on UDF function {func_name}: {err}"
))
});
.map_err(|err| handle_flight_decode_error(func_name, err));

let batches: Vec<RecordBatch> = record_batch_stream.try_collect().await?;
if batches.is_empty() {
Expand All @@ -399,6 +418,88 @@ impl UDFFlightClient {
}
}

/// Turns a Flight decoding failure for the UDF `func_name` into an
/// `ErrorCode::UDFDataError` with a message tailored to the kind of failure.
///
/// The failure is classified by `classify_flight_error`; every category maps
/// to the same error code and differs only in the wording of the message.
fn handle_flight_decode_error(func_name: &str, err: FlightError) -> ErrorCode {
    // Rendered once up front: every arm except `ServerStatus` appends it as detail.
    let err_text = err.to_string();

    let description = match classify_flight_error(&err) {
        FlightDecodeIssue::TransportInterrupted => format!(
            "The user-defined function \"{func_name}\" stopped responding before it finished. Retry the query; if it keeps failing, ensure the UDF server is running or review its logs. (details: {err_text})"
        ),
        FlightDecodeIssue::ServerStatus(status) => format!(
            "The user-defined function \"{func_name}\" reported an error: {status}. Review the UDF server logs."
        ),
        FlightDecodeIssue::SchemaMismatch => format!(
            "The user-defined function \"{func_name}\" returned an unexpected schema. Ensure the UDF definition matches the server output. (details: {err_text})"
        ),
        FlightDecodeIssue::MalformedData => format!(
            "The user-defined function \"{func_name}\" returned data that Databend could not parse. Check the UDF implementation or its logs. (details: {err_text})"
        ),
        FlightDecodeIssue::Other => format!(
            "Decode record batch failed on UDF function \"{func_name}\": {err_text}"
        ),
    };

    ErrorCode::UDFDataError(description)
}

/// Maps a `FlightError` to a `FlightDecodeIssue`.
///
/// Variants that wrap another error (`Arrow`, `Tonic`, `ExternalError`) are
/// delegated to the matching classifier; protocol and decode errors count as
/// malformed data; `NotYetImplemented` falls into `Other`.
fn classify_flight_error(err: &FlightError) -> FlightDecodeIssue {
    match err {
        FlightError::NotYetImplemented(_) => FlightDecodeIssue::Other,
        FlightError::DecodeError(_) | FlightError::ProtocolError(_) => {
            FlightDecodeIssue::MalformedData
        }
        FlightError::ExternalError(inner) => classify_external_error(inner.as_ref()),
        FlightError::Tonic(grpc_status) => classify_status(grpc_status),
        FlightError::Arrow(inner) => classify_arrow_error(inner),
    }
}

/// Maps an `ArrowError` to a `FlightDecodeIssue`.
///
/// I/O errors are classified by their message text, external errors by
/// `classify_external_error`, schema errors as `SchemaMismatch`, and the
/// listed parsing/format variants as `MalformedData`. Every other variant is
/// reported as `Other`.
fn classify_arrow_error(err: &ArrowError) -> FlightDecodeIssue {
    match err {
        ArrowError::IoError(text, _) => classify_error_message(text),
        ArrowError::ExternalError(cause) => classify_external_error(cause.as_ref()),
        ArrowError::SchemaError(_) => FlightDecodeIssue::SchemaMismatch,
        ArrowError::ParquetError(_)
        | ArrowError::CDataInterface(_)
        | ArrowError::IpcError(_)
        | ArrowError::CsvError(_)
        | ArrowError::JsonError(_)
        | ArrowError::ComputeError(_)
        | ArrowError::InvalidArgumentError(_)
        | ArrowError::ParseError(_) => FlightDecodeIssue::MalformedData,
        _ => FlightDecodeIssue::Other,
    }
}

/// Classifies a gRPC `Status` using only its message text; the status code is
/// not consulted.
fn classify_status(status: &Status) -> FlightDecodeIssue {
    let text = status.message();
    classify_error_message(text)
}

fn classify_external_error(error: &(dyn StdError + Send + Sync + 'static)) -> FlightDecodeIssue {
if let Some(arrow_err) = error.downcast_ref::<ArrowError>() {
classify_arrow_error(arrow_err)
} else if let Some(status) = error.downcast_ref::<Status>() {
classify_status(status)
} else if let Some(io_error) = error.downcast_ref::<std::io::Error>() {
classify_error_message(&io_error.to_string())
} else {
classify_error_message(&error.to_string())
}
}

/// Classifies a raw error message: text recognised by
/// `is_transport_error_message` becomes `TransportInterrupted`; anything else
/// is kept verbatim inside `ServerStatus`.
fn classify_error_message(message: &str) -> FlightDecodeIssue {
    if !is_transport_error_message(message) {
        return FlightDecodeIssue::ServerStatus(message.to_string());
    }
    FlightDecodeIssue::TransportInterrupted
}

/// Returns `true` when `message` contains any entry of
/// `TRANSPORT_ERROR_SNIPPETS`.
///
/// The message is ASCII-lowercased before searching, so matching ignores
/// ASCII case.
pub fn is_transport_error_message(message: &str) -> bool {
    let haystack = message.to_ascii_lowercase();
    for needle in TRANSPORT_ERROR_SNIPPETS {
        if haystack.contains(*needle) {
            return true;
        }
    }
    false
}
pub fn error_kind(message: &str) -> &str {
let message = message.to_ascii_lowercase();
if message.contains("timeout") || message.contains("timedout") {
Expand All @@ -418,3 +519,70 @@ pub fn error_kind(message: &str) -> &str {
"Other"
}
}

#[cfg(test)]
mod tests {
    use tonic::Code;

    use super::*;

    /// Builds the `FlightError` produced when the UDF server answers with an
    /// `Internal` gRPC status carrying `text`.
    fn internal_status(text: &str) -> FlightError {
        FlightError::Tonic(Box::new(Status::new(Code::Internal, text)))
    }

    // A status whose text matches a transport pattern yields the "retry" hint.
    #[test]
    fn transport_error_returns_interrupt_hint() {
        let source = internal_status("h2 protocol error: error reading a body from connection");
        let error_code = handle_flight_decode_error("test_udf", source);
        let message = error_code.message();
        assert!(
            message.contains("stopped responding before it finished"),
            "unexpected transport hint: {message}"
        );
    }

    // A non-transport status text is passed through to the user unchanged.
    #[test]
    fn server_status_is_preserved() {
        let source = internal_status("remote handler returned validation error");
        let error_code = handle_flight_decode_error("test_udf", source);
        let message = error_code.message();
        assert!(
            message.contains("reported an error: remote handler returned validation error"),
            "unexpected server status message: {message}"
        );
    }

    // Arrow schema errors produce the schema-mismatch wording.
    #[test]
    fn schema_mismatch_detected() {
        let source = FlightError::Arrow(ArrowError::SchemaError(
            "expected Int32, got Utf8".to_string(),
        ));
        let error_code = handle_flight_decode_error("test_udf", source);
        let message = error_code.message();
        assert!(
            message.contains("returned an unexpected schema"),
            "schema mismatch hint missing: {message}"
        );
    }

    // Arrow parse errors produce the malformed-data wording.
    #[test]
    fn malformed_data_reported() {
        let source = FlightError::Arrow(ArrowError::ParseError("bad payload".to_string()));
        let error_code = handle_flight_decode_error("test_udf", source);
        let message = error_code.message();
        assert!(
            message.contains("could not parse"),
            "malformed data hint missing: {message}"
        );
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use databend_common_catalog::table_context::TableContext;
use databend_common_exception::ErrorCode;
use databend_common_exception::Result;
use databend_common_expression::udf_client::error_kind;
use databend_common_expression::udf_client::is_transport_error_message;
use databend_common_expression::udf_client::UDFFlightClient;
use databend_common_expression::BlockEntry;
use databend_common_expression::ColumnBuilder;
Expand Down Expand Up @@ -156,7 +157,7 @@ fn retry_on(err: &databend_common_exception::ErrorCode) -> bool {
if err.code() == ErrorCode::U_D_F_DATA_ERROR {
let message = err.message();
// this means the server can't handle the request in 60s
if message.contains("h2 protocol error") {
if is_transport_error_message(&message) {
return false;
}
}
Expand Down
1 change: 1 addition & 0 deletions src/query/service/tests/it/pipelines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@
mod executor;
mod filter;
mod transforms;
mod udf_transport;
Loading
Loading