diff --git a/CMakeLists.txt b/CMakeLists.txt index eafcec5..596ea6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,11 +9,12 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) include_directories(src/include) -set(EXTENSION_SOURCES +set(EXTENSION_SOURCES src/parser_tools_extension.cpp src/parse_tables.cpp src/parse_where.cpp src/parse_functions.cpp + src/parse_statements.cpp ) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) diff --git a/README.md b/README.md index 02b8e04..dfc54da 100644 --- a/README.md +++ b/README.md @@ -4,13 +4,14 @@ An experimental DuckDB extension that exposes functionality from DuckDB's native ## Overview -`parser_tools` is a DuckDB extension designed to provide SQL parsing capabilities within the database. It allows you to analyze SQL queries and extract structural information directly in SQL. This extension provides parsing functions for tables, WHERE clauses, and function calls (see [Functions](#functions) below). +`parser_tools` is a DuckDB extension designed to provide SQL parsing capabilities within the database. It allows you to analyze SQL queries and extract structural information directly in SQL. This extension provides parsing functions for tables, WHERE clauses, function calls, and statements. ## Features - **Extract table references** from a SQL query with context information (e.g. `FROM`, `JOIN`, etc.) - **Extract function calls** from a SQL query with context information (e.g. `SELECT`, `WHERE`, `HAVING`, etc.) - **Parse WHERE clauses** to extract conditions and operators +- **Parse multi-statement SQL** to extract individual statements or count the number of statements - Support for **window functions**, **nested functions**, and **CTEs** - Includes **schema**, **name**, and **context** information for all extractions - Built on DuckDB's native SQL parser @@ -94,7 +95,7 @@ Context helps identify where elements are used in the query. 
## Functions -This extension provides parsing functions for tables, functions, and WHERE clauses. Each category includes both table functions (for detailed results) and scalar functions (for programmatic use). +This extension provides parsing functions for tables, functions, WHERE clauses, and statements. Each category includes both table functions (for detailed results) and scalar functions (for programmatic use). In general, errors (e.g. Parse Exception) will not be exposed to the user, but instead will result in an empty result. This simplifies batch processing. When validity is needed, [is_parsable](#is_parsablesql_query--scalar-function) can be used. @@ -319,6 +320,92 @@ FROM (VALUES └───────────────────────────────────────────────┴────────┘ ``` +--- + +### Statement Parsing Functions + +These functions parse multi-statement SQL strings and extract individual statements or count them. + +#### `parse_statements(sql_query)` – Table Function + +Parses a SQL string containing one or more statements and returns each statement as a separate row. + +##### Usage +```sql +SELECT * FROM parse_statements('SELECT 42; SELECT 43;'); +``` + +##### Returns +A table with: +- `statement`: the SQL statement text as re-serialized by DuckDB's parser (normalized form, e.g. `count(*)` is shown as `count_star()`), not the original input text + +##### Example +```sql +SELECT * FROM parse_statements($$ + SELECT * FROM users WHERE active = true; + INSERT INTO log VALUES ('query executed'); + SELECT count(*) FROM transactions; +$$); +``` + +| statement | +|-----------| +| SELECT * FROM users WHERE (active = true) | +| INSERT INTO log (VALUES ('query executed')) | +| SELECT count_star() FROM transactions | + +--- + +#### `parse_statements(sql_query)` – Scalar Function + +Returns a list of statement strings from a multi-statement SQL query. + +##### Usage +```sql +SELECT parse_statements('SELECT 42; SELECT 43;'); +---- +[SELECT 42, SELECT 43] +``` + +##### Returns +A list of strings, one per statement, in the same normalized form as the table function. Unparsable input yields an empty list.
+ +##### Example +```sql +SELECT parse_statements('SELECT 1; INSERT INTO test VALUES (2); SELECT 3;'); +---- +[SELECT 1, 'INSERT INTO test (VALUES (2))', SELECT 3] +``` + +--- + +#### `num_statements(sql_query)` – Scalar Function + +Returns the number of statements in a multi-statement SQL query. + +##### Usage +```sql +SELECT num_statements('SELECT 42; SELECT 43;'); +---- +2 +``` + +##### Returns +An integer count of the number of SQL statements. + +##### Example +```sql +SELECT num_statements($$ + WITH cte AS (SELECT 1) SELECT * FROM cte; + UPDATE users SET last_seen = now(); + SELECT count(*) FROM users; + DELETE FROM temp_data; +$$); +---- +4 +``` + +--- ## Development diff --git a/src/include/parse_statements.hpp b/src/include/parse_statements.hpp new file mode 100644 index 0000000..8d3006b --- /dev/null +++ b/src/include/parse_statements.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include "duckdb.hpp" +#include <string> +#include <vector> + +namespace duckdb { + +// Forward declarations +class ExtensionLoader; + +struct StatementResult { + std::string statement; +}; + +void RegisterParseStatementsFunction(ExtensionLoader &loader); +void RegisterParseStatementsScalarFunction(ExtensionLoader &loader); + +} // namespace duckdb \ No newline at end of file diff --git a/src/parse_statements.cpp b/src/parse_statements.cpp new file mode 100644 index 0000000..7085793 --- /dev/null +++ b/src/parse_statements.cpp @@ -0,0 +1,149 @@ +#include "parse_statements.hpp" +#include "duckdb.hpp" +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/statement/select_statement.hpp" +#include "duckdb/function/scalar/nested_functions.hpp" + +namespace duckdb { + +struct ParseStatementsState : public GlobalTableFunctionState { + idx_t row = 0; + vector<StatementResult> results; +}; + +struct ParseStatementsBindData : public TableFunctionData { + string sql; +}; + +// BIND function: runs during query planning; fixes the output schema (one VARCHAR column "statement") and captures the SQL text +static unique_ptr<FunctionData> ParseStatementsBind(ClientContext &context, 
TableFunctionBindInput &input, + vector<LogicalType> &return_types, + vector<string> &names) { + // A NULL argument is treated as empty SQL (no rows); StringValue::Get must not be called on a NULL value + string sql_input = input.inputs[0].IsNull() ? string() : StringValue::Get(input.inputs[0]); + + // Return single column with statement text + return_types = {LogicalType::VARCHAR}; + names = {"statement"}; + + // Create a bind data object to hold the SQL input + auto result = make_uniq<ParseStatementsBindData>(); + result->sql = sql_input; + + return std::move(result); +} + +// INIT function: runs before table function execution +static unique_ptr<GlobalTableFunctionState> ParseStatementsInit(ClientContext &context, + TableFunctionInitInput &input) { + return make_uniq<ParseStatementsState>(); +} + +// Parses sql and appends one StatementResult per parsed statement to results; appends nothing on a ParserException +static void ExtractStatementsFromSQL(const std::string &sql, std::vector<StatementResult> &results) { + Parser parser; + + try { + parser.ParseQuery(sql); + } catch (const ParserException &ex) { + // Swallow parser exceptions: unparsable SQL yields no statements instead of an error + return; + } + + for (auto &stmt : parser.statements) { + if (stmt) { + // Convert statement back to string (the parser's re-serialized form, not the original text) + auto statement_str = stmt->ToString(); + results.push_back(StatementResult{statement_str}); + } + } +} + +// Table function body: parses the bound SQL on first use, then emits one statement row per call until all are returned +static void ParseStatementsFunction(ClientContext &context, + TableFunctionInput &data, + DataChunk &output) { + auto &state = data.global_state->Cast<ParseStatementsState>(); + auto &bind_data = data.bind_data->Cast<ParseStatementsBindData>(); + + if (state.results.empty() && state.row == 0) { + ExtractStatementsFromSQL(bind_data.sql, state.results); + } + + if (state.row >= state.results.size()) { + return; + } + + auto &stmt = state.results[state.row]; + output.SetCardinality(1); + output.SetValue(0, 0, Value(stmt.statement)); + + state.row++; +} + +// Scalar parse_statements: appends each input's statements to the result list's child vector and returns its (offset, length) entry +static void ParseStatementsScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute<string_t, list_entry_t>(args.data[0], result, args.size(), + [&result](string_t query) -> list_entry_t { + // Parse the SQL query and extract statements + auto query_string = query.GetString(); + std::vector<StatementResult> parsed_statements; + ExtractStatementsFromSQL(query_string, parsed_statements); + + auto current_size = 
ListVector::GetListSize(result); + auto number_of_statements = parsed_statements.size(); + auto new_size = current_size + number_of_statements; + + // Grow list if needed + if (ListVector::GetListCapacity(result) < new_size) { + ListVector::Reserve(result, new_size); + } + + // Write the statements into the child vector + auto statements = FlatVector::GetData<string_t>(ListVector::GetEntry(result)); + for (size_t i = 0; i < parsed_statements.size(); i++) { + auto &stmt = parsed_statements[i]; + statements[current_size + i] = StringVector::AddStringOrBlob(ListVector::GetEntry(result), stmt.statement); + } + + // Update size + ListVector::SetListSize(result, new_size); + + return list_entry_t(current_size, number_of_statements); + }); +} + +// Scalar num_statements: returns the number of statements parsed from each input string (0 when parsing fails) +static void NumStatementsScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) { + UnaryExecutor::Execute<string_t, int64_t>(args.data[0], result, args.size(), + [](string_t query) -> int64_t { + // Parse the SQL query and count statements + auto query_string = query.GetString(); + std::vector<StatementResult> parsed_statements; + ExtractStatementsFromSQL(query_string, parsed_statements); + + return static_cast<int64_t>(parsed_statements.size()); + }); +} + +// Extension scaffolding +// --------------------------------------------------- + +void RegisterParseStatementsFunction(ExtensionLoader &loader) { + // Table function that returns one row per statement + TableFunction tf("parse_statements", {LogicalType::VARCHAR}, ParseStatementsFunction, ParseStatementsBind, ParseStatementsInit); + loader.RegisterFunction(tf); +} + +void RegisterParseStatementsScalarFunction(ExtensionLoader &loader) { + // parse_statements is a scalar function that returns a list of statement strings + ScalarFunction sf("parse_statements", {LogicalType::VARCHAR}, LogicalType::LIST(LogicalType::VARCHAR), ParseStatementsScalarFunction); + loader.RegisterFunction(sf); + + // num_statements is a scalar function that returns the count of statements + ScalarFunction num_sf("num_statements", 
{LogicalType::VARCHAR}, LogicalType::BIGINT, NumStatementsScalarFunction); + loader.RegisterFunction(num_sf); +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/parser_tools_extension.cpp b/src/parser_tools_extension.cpp index a324102..091a690 100644 --- a/src/parser_tools_extension.cpp +++ b/src/parser_tools_extension.cpp @@ -4,6 +4,7 @@ #include "parse_tables.hpp" #include "parse_where.hpp" #include "parse_functions.hpp" +#include "parse_statements.hpp" #include "duckdb.hpp" #include "duckdb/common/exception.hpp" #include "duckdb/common/string_util.hpp" @@ -30,6 +31,8 @@ static void LoadInternal(ExtensionLoader &loader) { RegisterParseWhereDetailedFunction(loader); RegisterParseFunctionsFunction(loader); RegisterParseFunctionScalarFunction(loader); + RegisterParseStatementsFunction(loader); + RegisterParseStatementsScalarFunction(loader); } void ParserToolsExtension::Load(ExtensionLoader &loader) { diff --git a/test/sql/parse_tools/scalar_functions/num_statements.test b/test/sql/parse_tools/scalar_functions/num_statements.test new file mode 100644 index 0000000..1a8f1e6 --- /dev/null +++ b/test/sql/parse_tools/scalar_functions/num_statements.test @@ -0,0 +1,79 @@ +# name: test/sql/parser_tools/scalar_functions/num_statements.test +# description: test num_statements scalar function +# group: [num_statements] + +# Before we load the extension, this will fail +statement error +SELECT num_statements('SELECT 42; SELECT 43;'); +---- +Catalog Error: Scalar Function with name num_statements does not exist! 
+ +# Require statement will ensure this test is run with this extension loaded +require parser_tools + +# Single statement +query I +SELECT num_statements('SELECT 42;'); +---- +1 + +# Two statements +query I +SELECT num_statements('SELECT 42; SELECT 43;'); +---- +2 + +# Three statements +query I +SELECT num_statements('SELECT 1; SELECT 2; SELECT 3;'); +---- +3 + +# Multiple statements with different types +query I +SELECT num_statements('SELECT 1; INSERT INTO test VALUES (2); UPDATE test SET x = 1; SELECT 3;'); +---- +4 + +# Complex multi-statement query +query I +SELECT num_statements($$ + WITH cte AS (SELECT 1 as a) SELECT * FROM cte; + SELECT upper('hello'); + SELECT count(*) FROM users WHERE id > 10; + INSERT INTO log VALUES ('done'); +$$); +---- +4 + +# Single complex statement +query I +SELECT num_statements($$ + WITH cte1 AS (SELECT * FROM table1), + cte2 AS (SELECT * FROM table2) + SELECT cte1.id, cte2.name + FROM cte1 + JOIN cte2 ON cte1.id = cte2.id + WHERE cte1.active = true + ORDER BY cte1.created_at DESC; +$$); +---- +1 + +# Empty input +query I +SELECT num_statements(''); +---- +0 + +# Whitespace only +query I +SELECT num_statements(' '); +---- +0 + +# Invalid SQL should return 0 +query I +SELECT num_statements('INVALID SQL SYNTAX HERE'); +---- +0 \ No newline at end of file diff --git a/test/sql/parse_tools/scalar_functions/parse_statements.test b/test/sql/parse_tools/scalar_functions/parse_statements.test new file mode 100644 index 0000000..bbba3e1 --- /dev/null +++ b/test/sql/parse_tools/scalar_functions/parse_statements.test @@ -0,0 +1,64 @@ +# name: test/sql/parser_tools/scalar_functions/parse_statements.test +# description: test parse_statements scalar function +# group: [parse_statements] + +# Before we load the extension, this will fail +statement error +SELECT parse_statements('SELECT 42; SELECT 43;'); +---- +Catalog Error: Scalar Function with name parse_statements does not exist! 
+ +# Require statement will ensure this test is run with this extension loaded +require parser_tools + +# Single statement +query I +SELECT parse_statements('SELECT 42;'); +---- +[SELECT 42] + +# Multiple statements +query I +SELECT parse_statements('SELECT 42; SELECT 43;'); +---- +[SELECT 42, SELECT 43] + +# Three statements +query I +SELECT parse_statements('SELECT 1; SELECT 2; SELECT 3;'); +---- +[SELECT 1, SELECT 2, SELECT 3] + +# Multiple statements with different types +query I +SELECT parse_statements('SELECT 1; INSERT INTO test VALUES (2); SELECT 3;'); +---- +[SELECT 1, 'INSERT INTO test (VALUES (2))', SELECT 3] + +# Complex multi-statement query +query I +SELECT parse_statements($$ + WITH cte AS (SELECT 1 as a) SELECT * FROM cte; + SELECT upper('hello'); + SELECT count(*) FROM users WHERE id > 10; +$$); +---- +['WITH cte AS (SELECT 1 AS a)SELECT * FROM cte', 'SELECT upper(\'hello\')', 'SELECT count_star() FROM users WHERE (id > 10)'] + +# Empty input +query I +SELECT parse_statements(''); +---- +[] + +# Whitespace only +query I +SELECT parse_statements(' '); +---- +[] + +# Invalid SQL should return empty list +query I +SELECT parse_statements('INVALID SQL SYNTAX HERE'); +---- +[] \ No newline at end of file diff --git a/test/sql/parse_tools/table_functions/parse_statements.test b/test/sql/parse_tools/table_functions/parse_statements.test new file mode 100644 index 0000000..d0041b2 --- /dev/null +++ b/test/sql/parse_tools/table_functions/parse_statements.test @@ -0,0 +1,70 @@ +# name: test/sql/parser_tools/table_functions/parse_statements.test +# description: test parse_statements table function +# group: [parse_statements] + +# Before we load the extension, this will fail +statement error +SELECT * FROM parse_statements('SELECT 42; SELECT 43;'); +---- +Catalog Error: Table Function with name parse_statements does not exist! 
+ +# Require statement will ensure this test is run with this extension loaded +require parser_tools + +# Single statement +query I +SELECT * FROM parse_statements('SELECT 42;'); +---- +SELECT 42 + +# Multiple statements +query I +SELECT * FROM parse_statements('SELECT 42; SELECT 43;'); +---- +SELECT 42 +SELECT 43 + +# Multiple statements with different types +query I +SELECT * FROM parse_statements('SELECT 1; INSERT INTO test VALUES (2); SELECT 3;'); +---- +SELECT 1 +INSERT INTO test (VALUES (2)) +SELECT 3 + +# Complex multi-statement query +query I +SELECT * FROM parse_statements($$ + WITH cte AS (SELECT 1 as a) SELECT * FROM cte; + SELECT upper('hello'); + SELECT count(*) FROM users WHERE id > 10; +$$); +---- +WITH cte AS (SELECT 1 AS a)SELECT * FROM cte +SELECT upper('hello') +SELECT count_star() FROM users WHERE (id > 10) + +# Statements with CTEs and joins +query I +SELECT * FROM parse_statements($$ + SELECT a.id FROM table_a a JOIN table_b b ON a.id = b.id; + WITH data AS (SELECT * FROM source) SELECT count(*) FROM data; +$$); +---- +SELECT a.id FROM table_a AS a INNER JOIN table_b AS b ON ((a.id = b.id)) +WITH "data" AS (SELECT * FROM "source")SELECT count_star() FROM "data" + +# Empty input +query I +SELECT * FROM parse_statements(''); +---- + +# Whitespace only +query I +SELECT * FROM parse_statements(' '); +---- + +# Invalid SQL should return no results +query I +SELECT * FROM parse_statements('INVALID SQL SYNTAX HERE'); +---- \ No newline at end of file