Skip to content

Commit

Permalink
feat: implement statement/execution timeout session variable (#4792)
Browse files Browse the repository at this point in the history
* support set and show on statement/execution timeout session variables.

* implement statement timeout for mysql read, and postgres queries

* add mysql test with max execution time
  • Loading branch information
lyang24 authored Nov 15, 2024
1 parent 42bf7e9 commit cdba7b4
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/recordbatch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pin-project.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true
tokio.workspace = true

[dev-dependencies]
tokio.workspace = true
9 changes: 9 additions & 0 deletions src/common/recordbatch/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Stream timeout"))]
StreamTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: tokio::time::error::Elapsed,
},
}

impl ErrorExt for Error {
Expand Down Expand Up @@ -190,6 +197,8 @@ impl ErrorExt for Error {
Error::SchemaConversion { source, .. } | Error::CastVector { source, .. } => {
source.status_code()
}

Error::StreamTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
1 change: 1 addition & 0 deletions src/operator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ workspace = true

[dependencies]
api.workspace = true
async-stream.workspace = true
async-trait = "0.1"
catalog.workspace = true
chrono.workspace = true
Expand Down
10 changes: 10 additions & 0 deletions src/operator/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::parquet;
use datatypes::arrow::error::ArrowError;
use snafu::{Location, Snafu};
use table::metadata::TableType;
use tokio::time::error::Elapsed;

#[derive(Snafu)]
#[snafu(visibility(pub))]
Expand Down Expand Up @@ -777,6 +778,14 @@ pub enum Error {
location: Location,
json: String,
},

#[snafu(display("Canceling statement due to statement timeout"))]
StatementTimeout {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: Elapsed,
},
}

pub type Result<T> = std::result::Result<T, Error>;
Expand Down Expand Up @@ -924,6 +933,7 @@ impl ErrorExt for Error {
Error::BuildRecordBatch { source, .. } => source.status_code(),

Error::UpgradeCatalogManagerRef { .. } => StatusCode::Internal,
Error::StatementTimeout { .. } => StatusCode::Cancelled,
}
}

Expand Down
93 changes: 88 additions & 5 deletions src/operator/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ mod show;
mod tql;

use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;

use async_stream::stream;
use catalog::kvbackend::KvBackendCatalogManager;
use catalog::CatalogManagerRef;
use client::RecordBatches;
use client::{OutputData, RecordBatches};
use common_error::ext::BoxedError;
use common_meta::cache::TableRouteCacheRef;
use common_meta::cache_invalidator::CacheInvalidatorRef;
Expand All @@ -39,15 +42,19 @@ use common_meta::key::view_info::{ViewInfoManager, ViewInfoManagerRef};
use common_meta::key::{TableMetadataManager, TableMetadataManagerRef};
use common_meta::kv_backend::KvBackendRef;
use common_query::Output;
use common_recordbatch::error::StreamTimeoutSnafu;
use common_recordbatch::RecordBatchStreamWrapper;
use common_telemetry::tracing;
use common_time::range::TimestampRange;
use common_time::Timestamp;
use datafusion_expr::LogicalPlan;
use futures::stream::{Stream, StreamExt};
use partition::manager::{PartitionRuleManager, PartitionRuleManagerRef};
use query::parser::QueryStatement;
use query::QueryEngineRef;
use session::context::{Channel, QueryContextRef};
use session::table_name::table_idents_to_full_name;
use set::set_query_timeout;
use snafu::{ensure, OptionExt, ResultExt};
use sql::statements::copy::{CopyDatabase, CopyDatabaseArgument, CopyTable, CopyTableArgument};
use sql::statements::set_variables::SetVariables;
Expand All @@ -63,8 +70,8 @@ use table::TableRef;
use self::set::{set_bytea_output, set_datestyle, set_timezone, validate_client_encoding};
use crate::error::{
self, CatalogSnafu, ExecLogicalPlanSnafu, ExternalSnafu, InvalidSqlSnafu, NotSupportedSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, TableMetadataManagerSnafu, TableNotFoundSnafu,
UpgradeCatalogManagerRefSnafu,
PlanStatementSnafu, Result, SchemaNotFoundSnafu, StatementTimeoutSnafu,
TableMetadataManagerSnafu, TableNotFoundSnafu, UpgradeCatalogManagerRefSnafu,
};
use crate::insert::InserterRef;
use crate::statement::copy_database::{COPY_DATABASE_TIME_END_KEY, COPY_DATABASE_TIME_START_KEY};
Expand Down Expand Up @@ -338,6 +345,28 @@ impl StatementExecutor {
"DATESTYLE" => set_datestyle(set_var.value, query_ctx)?,

"CLIENT_ENCODING" => validate_client_encoding(set_var)?,
"MAX_EXECUTION_TIME" => match query_ctx.channel() {
Channel::Mysql => set_query_timeout(set_var.value, query_ctx)?,
Channel::Postgres => {
query_ctx.set_warning(format!("Unsupported set variable {}", var_name))
}
_ => {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail()
}
},
"STATEMENT_TIMEOUT" => {
if query_ctx.channel() == Channel::Postgres {
set_query_timeout(set_var.value, query_ctx)?
} else {
return NotSupportedSnafu {
feat: format!("Unsupported set variable {}", var_name),
}
.fail();
}
}
_ => {
// for postgres, we give unknown SET statements a warning with
// success, this is prevent the SET call becoming a blocker
Expand Down Expand Up @@ -387,8 +416,19 @@ impl StatementExecutor {

#[tracing::instrument(skip_all)]
async fn plan_exec(&self, stmt: QueryStatement, query_ctx: QueryContextRef) -> Result<Output> {
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
let timeout = derive_timeout(&stmt, &query_ctx);
match timeout {
Some(timeout) => {
let start = tokio::time::Instant::now();
let output = tokio::time::timeout(timeout, self.plan_exec_inner(stmt, query_ctx))
.await
.context(StatementTimeoutSnafu)?;
// compute remaining timeout
let remaining_timeout = timeout.checked_sub(start.elapsed()).unwrap_or_default();
Ok(attach_timeout(output?, remaining_timeout))
}
None => self.plan_exec_inner(stmt, query_ctx).await,
}
}

async fn get_table(&self, table_ref: &TableReference<'_>) -> Result<TableRef> {
Expand All @@ -405,6 +445,49 @@ impl StatementExecutor {
table_name: table_ref.to_string(),
})
}

async fn plan_exec_inner(
&self,
stmt: QueryStatement,
query_ctx: QueryContextRef,
) -> Result<Output> {
let plan = self.plan(&stmt, query_ctx.clone()).await?;
self.exec_plan(plan, query_ctx).await
}
}

fn attach_timeout(output: Output, mut timeout: Duration) -> Output {
match output.data {
OutputData::AffectedRows(_) | OutputData::RecordBatches(_) => output,
OutputData::Stream(mut stream) => {
let schema = stream.schema();
let s = Box::pin(stream! {
let start = tokio::time::Instant::now();
while let Some(item) = tokio::time::timeout(timeout, stream.next()).await.context(StreamTimeoutSnafu)? {
yield item;
timeout = timeout.checked_sub(tokio::time::Instant::now() - start).unwrap_or(Duration::ZERO);
}
}) as Pin<Box<dyn Stream<Item = _> + Send>>;
let stream = RecordBatchStreamWrapper {
schema,
stream: s,
output_ordering: None,
metrics: Default::default(),
};
Output::new(OutputData::Stream(Box::pin(stream)), output.meta)
}
}
}

/// If the relevant variables are set, the timeout is enforced for all PostgreSQL statements.
/// For MySQL, it applies only to read-only statements.
fn derive_timeout(stmt: &QueryStatement, query_ctx: &QueryContextRef) -> Option<Duration> {
let query_timeout = query_ctx.query_timeout()?;
match (query_ctx.channel(), stmt) {
(Channel::Mysql, QueryStatement::Sql(Statement::Query(_)))
| (Channel::Postgres, QueryStatement::Sql(_)) => Some(query_timeout),
(_, _) => None,
}
}

fn to_copy_table_request(stmt: CopyTable, query_ctx: QueryContextRef) -> Result<CopyTableRequest> {
Expand Down
107 changes: 107 additions & 0 deletions src/operator/src/statement/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::time::Duration;

use common_time::Timezone;
use lazy_static::lazy_static;
use regex::Regex;
use session::context::Channel::Postgres;
use session::context::QueryContextRef;
use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle};
use snafu::{ensure, OptionExt, ResultExt};
Expand All @@ -21,6 +26,15 @@ use sql::statements::set_variables::SetVariables;

use crate::error::{InvalidConfigValueSnafu, InvalidSqlSnafu, NotSupportedSnafu, Result};

lazy_static! {
// Regex rules:
// The string must start with a number (one or more digits).
// The number must be followed by one of the valid time units (ms, s, min, h, d).
// The string must end immediately after the unit, meaning there can be no extra
// characters or spaces after the valid time specification.
static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap();
}

pub fn set_timezone(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let tz_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timezone find in set variable statement",
Expand Down Expand Up @@ -177,3 +191,96 @@ pub fn set_datestyle(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
.set_pg_datetime_style(style.unwrap_or(old_style), order.unwrap_or(older_order));
Ok(())
}

pub fn set_query_timeout(exprs: Vec<Expr>, ctx: QueryContextRef) -> Result<()> {
let timeout_expr = exprs.first().context(NotSupportedSnafu {
feat: "No timeout value find in set query timeout statement",
})?;
match timeout_expr {
Expr::Value(Value::Number(timeout, _)) => {
match timeout.parse::<u64>() {
Ok(timeout) => ctx.set_query_timeout(Duration::from_millis(timeout)),
Err(_) => {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail()
}
}
Ok(())
}
// postgres support time units i.e. SET STATEMENT_TIMEOUT = '50ms';
Expr::Value(Value::SingleQuotedString(timeout))
| Expr::Value(Value::DoubleQuotedString(timeout)) => {
if ctx.channel() != Postgres {
return NotSupportedSnafu {
feat: format!("Invalid timeout expr {} in set variable statement", timeout),
}
.fail();
}
let timeout = parse_pg_query_timeout_input(timeout)?;
ctx.set_query_timeout(Duration::from_millis(timeout));
Ok(())
}
expr => NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
expr
),
}
.fail(),
}
}

// support time units in ms, s, min, h, d for postgres protocol.
// https://www.postgresql.org/docs/8.4/config-setting.html#:~:text=Valid%20memory%20units%20are%20kB,%2C%20and%20d%20(days).
fn parse_pg_query_timeout_input(input: &str) -> Result<u64> {
match input.parse::<u64>() {
Ok(timeout) => Ok(timeout),
Err(_) => {
if let Some(captures) = PG_TIME_INPUT_REGEX.captures(input) {
let value = captures[1].parse::<u64>().expect("regex failed");
let unit = &captures[2];

match unit {
"ms" => Ok(value),
"s" => Ok(value * 1000),
"min" => Ok(value * 60 * 1000),
"h" => Ok(value * 60 * 60 * 1000),
"d" => Ok(value * 24 * 60 * 60 * 1000),
_ => unreachable!("regex failed"),
}
} else {
NotSupportedSnafu {
feat: format!(
"Unsupported timeout expr {} in set variable statement",
input
),
}
.fail()
}
}
}
}

#[cfg(test)]
mod test {
use crate::statement::set::parse_pg_query_timeout_input;

#[test]
fn test_parse_pg_query_timeout_input() {
assert!(parse_pg_query_timeout_input("").is_err());
assert!(parse_pg_query_timeout_input(" 50 ms").is_err());
assert!(parse_pg_query_timeout_input("5s 1ms").is_err());
assert!(parse_pg_query_timeout_input("3a").is_err());
assert!(parse_pg_query_timeout_input("1.5min").is_err());
assert!(parse_pg_query_timeout_input("ms").is_err());
assert!(parse_pg_query_timeout_input("a").is_err());
assert!(parse_pg_query_timeout_input("-1").is_err());

assert_eq!(50, parse_pg_query_timeout_input("50").unwrap());
assert_eq!(12, parse_pg_query_timeout_input("12ms").unwrap());
assert_eq!(2000, parse_pg_query_timeout_input("2s").unwrap());
assert_eq!(60000, parse_pg_query_timeout_input("1min").unwrap());
}
}
19 changes: 18 additions & 1 deletion src/query/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ use datatypes::vectors::StringVector;
use object_store::ObjectStore;
use once_cell::sync::Lazy;
use regex::Regex;
use session::context::QueryContextRef;
use session::context::{Channel, QueryContextRef};
pub use show_create_table::create_table_stmt;
use snafu::{ensure, OptionExt, ResultExt};
use sql::ast::Ident;
Expand Down Expand Up @@ -651,6 +651,23 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result<
let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style();
format!("{}, {}", style, order)
}
"MAX_EXECUTION_TIME" => {
if query_ctx.channel() == Channel::Mysql {
query_ctx.query_timeout_as_millis().to_string()
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
"STATEMENT_TIMEOUT" => {
// Add time units to postgres query timeout display.
if query_ctx.channel() == Channel::Postgres {
let mut timeout = query_ctx.query_timeout_as_millis().to_string();
timeout.push_str("ms");
timeout
} else {
return UnsupportedVariableSnafu { name: variable }.fail();
}
}
_ => return UnsupportedVariableSnafu { name: variable }.fail(),
};
let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
Expand Down
Loading

0 comments on commit cdba7b4

Please sign in to comment.