Skip to content

Commit

Permalink
feat: support local variable substitution, add __DATABASE__ variabl…
Browse files Browse the repository at this point in the history
…e for bin (#253)

* support local variable substitution

Signed-off-by: Richard Chien <[email protected]>

* update changelog

Signed-off-by: Richard Chien <[email protected]>

---------

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Feb 14, 2025
1 parent 89d0d3c commit 2003b1a
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 48 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

* runner: Add `Runner::set_var` method to allow adding runner-local variables for substitution.
* bin: Add `__DATABASE__` variable for accessing current database name from SLT files.

## [0.27.0] - 2025-02-11

* runner: add `shutdown` method to `DB` and `AsyncDB` trait to allow for graceful shutdown of the database connection. Users are encouraged to call `Runner::shutdown` or `Runner::shutdown_async` after running tests to ensure that the database connections are properly closed.
Expand Down
4 changes: 4 additions & 0 deletions sqllogictest-bin/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use itertools::Itertools;
use quick_junit::{NonSuccessKind, Report, TestCase, TestCaseStatus, TestSuite};
use rand::distributions::DistString;
use rand::seq::SliceRandom;
use sqllogictest::substitution::well_known;
use sqllogictest::{
default_column_validator, default_normalizer, default_validator, update_record_with_output,
AsyncDB, Injected, MakeConnection, Record, Runner,
Expand Down Expand Up @@ -466,6 +467,7 @@ async fn run_serial(
for label in labels {
runner.add_label(label);
}
runner.set_var(well_known::DATABASE.to_owned(), config.db.clone());

let filename = file.to_string_lossy().to_string();
let test_case_name = filename.replace(['/', ' ', '.', '-'], "_");
Expand Down Expand Up @@ -539,6 +541,7 @@ async fn update_test_files(
) -> Result<()> {
for file in files {
let mut runner = Runner::new(|| engines::connect(engine, &config));
runner.set_var(well_known::DATABASE.to_owned(), config.db.clone());

if let Err(e) = update_test_file(&mut std::io::stdout(), &mut runner, &file, format).await {
{
Expand Down Expand Up @@ -568,6 +571,7 @@ async fn connect_and_run_test_file(
for label in labels {
runner.add_label(label);
}
runner.set_var(well_known::DATABASE.to_owned(), config.db.clone());
let result = run_test_file(out, &mut runner, filename).await;
runner.shutdown_async().await;

Expand Down
3 changes: 1 addition & 2 deletions sqllogictest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ pub mod connection;
pub mod harness;
pub mod parser;
pub mod runner;
pub mod substitution;

pub use self::column_type::*;
pub use self::connection::*;
pub use self::parser::*;
pub use self::runner::*;

mod substitution;
65 changes: 52 additions & 13 deletions sqllogictest/src/runner.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
//! Sqllogictest runner.
use std::collections::HashSet;
use std::collections::{BTreeMap, HashSet};
use std::fmt::{Debug, Display};
use std::path::Path;
use std::process::{Command, ExitStatus, Output};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use std::vec;

Expand All @@ -16,6 +16,7 @@ use md5::Digest;
use owo_colors::OwoColorize;
use rand::Rng;
use similar::{Change, ChangeTag, TextDiff};
use tempfile::TempDir;

use crate::parser::*;
use crate::substitution::Substitution;
Expand Down Expand Up @@ -521,6 +522,36 @@ pub fn strict_column_validator<T: ColumnType>(actual: &Vec<T>, expected: &Vec<T>
.any(|(actual_column, expected_column)| actual_column != expected_column)
}

#[derive(Default)]
pub(crate) struct RunnerLocals {
/// The temporary directory. Test cases can use `__TEST_DIR__` to refer to this directory.
/// Lazily initialized and cleaned up when dropped.
test_dir: OnceLock<TempDir>,
/// Runtime variables for substitution.
variables: BTreeMap<String, String>,
}

impl RunnerLocals {
pub fn test_dir(&self) -> String {
let test_dir = self
.test_dir
.get_or_init(|| TempDir::new().expect("failed to create testdir"));
test_dir.path().to_string_lossy().into_owned()
}

fn set_var(&mut self, key: String, value: String) {
self.variables.insert(key, value);
}

pub fn get_var(&self, key: &str) -> Option<&String> {
self.variables.get(key)
}

pub fn vars(&self) -> &BTreeMap<String, String> {
&self.variables
}
}

/// Sqllogictest runner.
pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
conn: Connections<D, M>,
Expand All @@ -529,13 +560,15 @@ pub struct Runner<D: AsyncDB, M: MakeConnection<Conn = D>> {
// normalizer is used to normalize the result text
normalizer: Normalizer,
column_type_validator: ColumnTypeValidator<D::ColumnType>,
substitution: Option<Substitution>,
substitution_on: bool,
sort_mode: Option<SortMode>,
result_mode: Option<ResultMode>,
/// 0 means never hashing
hash_threshold: usize,
/// Labels for condition `skipif` and `onlyif`.
labels: HashSet<String>,
/// Local variables/context for the runner.
locals: RunnerLocals,
}

impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
Expand All @@ -547,12 +580,13 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
validator: default_validator,
normalizer: default_normalizer,
column_type_validator: default_column_validator,
substitution: None,
substitution_on: false,
sort_mode: None,
result_mode: None,
hash_threshold: 0,
labels: HashSet::new(),
conn: Connections::new(make_conn),
locals: RunnerLocals::default(),
}
}

Expand All @@ -561,6 +595,11 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
self.labels.insert(label.to_string());
}

/// Set a local variable for substitution.
pub fn set_var(&mut self, key: String, value: String) {
self.locals.set_var(key, value);
}

pub fn with_normalizer(&mut self, normalizer: Normalizer) {
self.normalizer = normalizer;
}
Expand Down Expand Up @@ -862,11 +901,7 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
Control::ResultMode(result_mode) => {
self.result_mode = Some(result_mode);
}
Control::Substitution(on_off) => match (&mut self.substitution, on_off) {
(s @ None, true) => *s = Some(Substitution::default()),
(s @ Some(_), false) => *s = None,
_ => {}
},
Control::Substitution(on_off) => self.substitution_on = on_off,
}

RecordOutput::Nothing
Expand Down Expand Up @@ -1260,18 +1295,22 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
.expect("create db failed");
let target = hosts[idx % hosts.len()].clone();

let mut locals = RunnerLocals::default();
locals.set_var("__DATABASE__".to_owned(), db_name.clone());

let mut tester = Runner {
conn: Connections::new(move || {
conn_builder(target.clone(), db_name.clone()).map(Ok)
}),
validator: self.validator,
normalizer: self.normalizer,
column_type_validator: self.column_type_validator,
substitution: self.substitution.clone(),
substitution_on: self.substitution_on,
sort_mode: self.sort_mode,
result_mode: self.result_mode,
hash_threshold: self.hash_threshold,
labels: self.labels.clone(),
locals,
};

tasks.push(async move {
Expand Down Expand Up @@ -1317,9 +1356,9 @@ impl<D: AsyncDB, M: MakeConnection<Conn = D>> Runner<D, M> {
/// This is useful for `system` commands: The shell can do the environment variables, and we can
/// write strings like `\n` without escaping.
fn may_substitute(&self, input: String, subst_env_vars: bool) -> Result<String, AnyError> {
if let Some(substitution) = &self.substitution {
substitution
.substitute(&input, subst_env_vars)
if self.substitution_on {
Substitution::new(&self.locals, subst_env_vars)
.substitute(&input)
.map_err(|e| Arc::new(e) as AnyError)
} else {
Ok(input)
Expand Down
84 changes: 52 additions & 32 deletions sqllogictest/src/substitution.rs
Original file line number Diff line number Diff line change
@@ -1,55 +1,75 @@
use std::sync::{Arc, OnceLock};

use subst::Env;
use tempfile::{tempdir, TempDir};

use crate::RunnerLocals;

pub mod well_known {
pub const TEST_DIR: &str = "__TEST_DIR__";
pub const NOW: &str = "__NOW__";
pub const DATABASE: &str = "__DATABASE__";
}

/// Substitute environment variables and special variables like `__TEST_DIR__` in SQL.
#[derive(Default, Clone)]
pub(crate) struct Substitution {
/// The temporary directory for `__TEST_DIR__`.
/// Lazily initialized and cleaned up when dropped.
test_dir: Arc<OnceLock<TempDir>>,
pub(crate) struct Substitution<'a> {
runner_locals: &'a RunnerLocals,
subst_env_vars: bool,
}

impl Substitution<'_> {
pub fn new(runner_locals: &RunnerLocals, subst_env_vars: bool) -> Substitution {
Substitution {
runner_locals,
subst_env_vars,
}
}
}

#[derive(thiserror::Error, Debug)]
#[error("substitution failed: {0}")]
pub(crate) struct SubstError(subst::Error);

impl Substitution {
pub fn substitute(&self, input: &str, subst_env_vars: bool) -> Result<String, SubstError> {
if !subst_env_vars {
Ok(input
.replace("$__TEST_DIR__", &self.test_dir())
.replace("$__NOW__", &self.now()))
} else {
fn now_string() -> String {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("failed to get current time")
.as_nanos()
.to_string()
}

impl Substitution<'_> {
pub fn substitute(&self, input: &str) -> Result<String, SubstError> {
if self.subst_env_vars {
subst::substitute(input, self).map_err(SubstError)
} else {
Ok(self.simple_replace(input))
}
}

fn test_dir(&self) -> String {
let test_dir = self
.test_dir
.get_or_init(|| tempdir().expect("failed to create testdir"));
test_dir.path().to_string_lossy().into_owned()
}

fn now(&self) -> String {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("failed to get current time")
.as_nanos()
.to_string()
fn simple_replace(&self, input: &str) -> String {
let mut res = input
.replace(
&format!("${}", well_known::TEST_DIR),
&self.runner_locals.test_dir(),
)
.replace(&format!("${}", well_known::NOW), &now_string());
for (key, value) in self.runner_locals.vars() {
res = res.replace(&format!("${}", key), value);
}
res
}
}

impl<'a> subst::VariableMap<'a> for Substitution {
impl<'a> subst::VariableMap<'a> for Substitution<'a> {
type Value = String;

fn get(&'a self, key: &str) -> Option<Self::Value> {
match key {
"__TEST_DIR__" => self.test_dir().into(),
"__NOW__" => self.now().into(),
key => Env.get(key),
well_known::TEST_DIR => self.runner_locals.test_dir().into(),
well_known::NOW => now_string().into(),
key => self
.runner_locals
.get_var(key)
.cloned()
.or_else(|| Env.get(key)),
}
}
}
4 changes: 4 additions & 0 deletions tests/substitution/basic.slt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ path $__TEST_DIR__
statement ok
time $__NOW__

# a local variable set before running tester
statement ok
check $__DATABASE__

# non existent variables without default values are errors
statement error No such variable
check $NONEXISTENT_VARIABLE
Expand Down
3 changes: 2 additions & 1 deletion tests/substitution/substitution.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use rusty_fork::rusty_fork_test;
use sqllogictest::{DBOutput, DefaultColumnType};
use sqllogictest::{substitution::well_known, DBOutput, DefaultColumnType};

pub struct FakeDB;

Expand Down Expand Up @@ -59,6 +59,7 @@ rusty_fork_test! {
std::env::set_var("MY_PASSWORD", "rust");

let mut tester = sqllogictest::Runner::new(|| async { Ok(FakeDB) });
tester.set_var(well_known::DATABASE.to_owned(), "fake_db".to_owned());

tester.run_file("./substitution/basic.slt").unwrap();
}
Expand Down

0 comments on commit 2003b1a

Please sign in to comment.