Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension)
project(${TARGET_NAME})
include_directories(src/include)

set(EXTENSION_SOURCES
set(EXTENSION_SOURCES
src/parser_tools_extension.cpp
src/parse_tables.cpp
src/parse_where.cpp
src/parse_functions.cpp
src/parse_statements.cpp
)

build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES})
Expand Down
91 changes: 89 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ An experimental DuckDB extension that exposes functionality from DuckDB's native

## Overview

`parser_tools` is a DuckDB extension designed to provide SQL parsing capabilities within the database. It allows you to analyze SQL queries and extract structural information directly in SQL. This extension provides parsing functions for tables, WHERE clauses, and function calls (see [Functions](#functions) below).
`parser_tools` is a DuckDB extension designed to provide SQL parsing capabilities within the database. It allows you to analyze SQL queries and extract structural information directly in SQL. This extension provides parsing functions for tables, WHERE clauses, function calls, and statements.

## Features

- **Extract table references** from a SQL query with context information (e.g. `FROM`, `JOIN`, etc.)
- **Extract function calls** from a SQL query with context information (e.g. `SELECT`, `WHERE`, `HAVING`, etc.)
- **Parse WHERE clauses** to extract conditions and operators
- **Parse multi-statement SQL** to extract individual statements or count the number of statements
- Support for **window functions**, **nested functions**, and **CTEs**
- Includes **schema**, **name**, and **context** information for all extractions
- Built on DuckDB's native SQL parser
Expand Down Expand Up @@ -94,7 +95,7 @@ Context helps identify where elements are used in the query.

## Functions

This extension provides parsing functions for tables, functions, and WHERE clauses. Each category includes both table functions (for detailed results) and scalar functions (for programmatic use).
This extension provides parsing functions for tables, functions, WHERE clauses, and statements. Each category includes both table functions (for detailed results) and scalar functions (for programmatic use).

In general, errors (e.g. Parse Exception) will not be exposed to the user, but instead will result in an empty result. This simplifies batch processing. When validity is needed, [is_parsable](#is_parsablesql_query--scalar-function) can be used.

Expand Down Expand Up @@ -319,6 +320,92 @@ FROM (VALUES
└───────────────────────────────────────────────┴────────┘
```

---

### Statement Parsing Functions

These functions parse multi-statement SQL strings and extract individual statements or count them.

#### `parse_statements(sql_query)` – Table Function

Parses a SQL string containing multiple statements and returns each statement as a separate row.

##### Usage
```sql
SELECT * FROM parse_statements('SELECT 42; SELECT 43;');
```

##### Returns
A table with:
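- `statement`: the text of each statement as re-serialized by DuckDB's parser (normalized, so it may differ from the formatting of the original input, as the example below shows)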

##### Example
```sql
SELECT * FROM parse_statements($$
SELECT * FROM users WHERE active = true;
INSERT INTO log VALUES ('query executed');
SELECT count(*) FROM transactions;
$$);
```

| statement |
|-----------|
| SELECT * FROM users WHERE (active = true) |
| INSERT INTO log (VALUES ('query executed')) |
| SELECT count_star() FROM transactions |

---

#### `parse_statements(sql_query)` – Scalar Function

Returns a list of statement strings from a multi-statement SQL query.

##### Usage
```sql
SELECT parse_statements('SELECT 42; SELECT 43;');
----
[SELECT 42, SELECT 43]
```

##### Returns
A list of strings, each being a SQL statement.

##### Example
```sql
SELECT parse_statements('SELECT 1; INSERT INTO test VALUES (2); SELECT 3;');
----
[SELECT 1, 'INSERT INTO test (VALUES (2))', SELECT 3]
```

---

#### `num_statements(sql_query)` – Scalar Function

Returns the number of statements in a multi-statement SQL query.

##### Usage
```sql
SELECT num_statements('SELECT 42; SELECT 43;');
----
2
```

##### Returns
An integer count of the SQL statements in the input. Empty input, whitespace-only input, and SQL that fails to parse all return `0`.

##### Example
```sql
SELECT num_statements($$
WITH cte AS (SELECT 1) SELECT * FROM cte;
UPDATE users SET last_seen = now();
SELECT count(*) FROM users;
DELETE FROM temp_data;
$$);
----
4
```

---

## Development

Expand Down
19 changes: 19 additions & 0 deletions src/include/parse_statements.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "duckdb.hpp"
#include <string>
#include <vector>

namespace duckdb {

// Forward declarations
class ExtensionLoader;

// One statement extracted from a SQL string.
struct StatementResult {
// The statement rendered back to SQL text (produced via the parsed statement's ToString()).
std::string statement;
};

// Registers the `parse_statements` table function (one output row per statement) with the loader.
void RegisterParseStatementsFunction(ExtensionLoader &loader);
// Registers the scalar `parse_statements` (list of statement strings) and
// `num_statements` (statement count) functions with the loader.
void RegisterParseStatementsScalarFunction(ExtensionLoader &loader);

} // namespace duckdb
145 changes: 145 additions & 0 deletions src/parse_statements.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#include "parse_statements.hpp"
#include "duckdb.hpp"
#include "duckdb/parser/parser.hpp"
#include "duckdb/parser/statement/select_statement.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"

namespace duckdb {

// Global execution state of the parse_statements table function.
struct ParseStatementsState : public GlobalTableFunctionState {
// Index into `results` of the next statement to emit.
idx_t row = 0;
// Statements extracted from the bound SQL text; populated by ParseStatementsFunction on its first call.
vector<StatementResult> results;
};

// Bind-time data of the parse_statements table function.
struct ParseStatementsBindData : public TableFunctionData {
// The SQL text passed as the function's argument, captured in ParseStatementsBind.
string sql;
};

// BIND function: runs during query planning to decide the output schema.
//
// Declares a single VARCHAR output column named "statement" and stores the SQL
// text argument in the returned bind data so the execution function can parse it.
//
// A NULL argument is treated as an empty SQL string, which produces zero rows.
// StringValue::Get is not valid on a NULL Value, and an empty result mirrors how
// this file already handles input that fails to parse.
static unique_ptr<FunctionData> ParseStatementsBind(ClientContext &context,
                                                    TableFunctionBindInput &input,
                                                    vector<LogicalType> &return_types,
                                                    vector<string> &names) {
    // Return single column with statement text
    return_types = {LogicalType::VARCHAR};
    names = {"statement"};

    // Create a bind data object to hold the SQL input
    auto result = make_uniq<ParseStatementsBindData>();

    const auto &sql_arg = input.inputs[0];
    if (!sql_arg.IsNull()) {
        result->sql = StringValue::Get(sql_arg);
    }
    // else: leave `sql` empty so execution yields no rows

    // std::move is required for the unique_ptr<Derived> -> unique_ptr<Base> conversion
    return std::move(result);
}

// INIT function: invoked before the table function starts executing.
// Hands back a freshly default-constructed ParseStatementsState.
static unique_ptr<GlobalTableFunctionState> ParseStatementsInit(ClientContext &context,
                                                                TableFunctionInitInput &input) {
    auto fresh_state = make_uniq<ParseStatementsState>();
    return std::move(fresh_state);
}

static void ExtractStatementsFromSQL(const std::string &sql, std::vector<StatementResult> &results) {
Parser parser;

try {
parser.ParseQuery(sql);
} catch (const ParserException &ex) {
// Swallow parser exceptions to make this function more robust
return;
}

for (auto &stmt : parser.statements) {
if (stmt) {
// Convert statement back to string
auto statement_str = stmt->ToString();
results.push_back(StatementResult{statement_str});
}
}
}

// Execution function of the parse_statements table function.
//
// On the first call the bound SQL text is parsed into `state.results`. Each call
// then emits the next batch of statements (up to STANDARD_VECTOR_SIZE rows) into
// column 0 of `output` and advances `state.row`. Once every statement has been
// emitted the function returns without setting a cardinality, leaving `output`
// empty to signal the end of the scan.
static void ParseStatementsFunction(ClientContext &context,
                                    TableFunctionInput &data,
                                    DataChunk &output) {
    // Checked downcasts instead of C-style reference casts (which silently dropped
    // the const qualifier of the bind data).
    auto &state = data.global_state->Cast<ParseStatementsState>();
    auto &bind_data = data.bind_data->Cast<ParseStatementsBindData>();

    // Lazily parse on first use; with zero statements this re-runs, but the
    // function then returns immediately below.
    if (state.results.empty() && state.row == 0) {
        ExtractStatementsFromSQL(bind_data.sql, state.results);
    }

    if (state.row >= state.results.size()) {
        return;
    }

    // Fill as much of the chunk as possible instead of one row per call.
    const idx_t remaining = state.results.size() - state.row;
    const idx_t count = MinValue<idx_t>(remaining, STANDARD_VECTOR_SIZE);

    output.SetCardinality(count);
    for (idx_t i = 0; i < count; i++) {
        output.SetValue(0, i, Value(state.results[state.row + i].statement));
    }

    state.row += count;
}

// Scalar parse_statements: for every input SQL string, appends the rendered text
// of each extracted statement to the result list vector's child vector and
// returns the (offset, length) entry describing that row's slice.
static void ParseStatementsScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    auto &sql_column = args.data[0];
    const auto row_count = args.size();

    UnaryExecutor::Execute<string_t, list_entry_t>(sql_column, result, row_count,
        [&result](string_t sql_text) -> list_entry_t {
            // Extract the statements contained in this row's SQL text
            std::vector<StatementResult> extracted;
            ExtractStatementsFromSQL(sql_text.GetString(), extracted);

            const auto list_offset = ListVector::GetListSize(result);
            const auto list_length = extracted.size();
            const auto required_size = list_offset + list_length;

            // Make room for the new entries when the child vector is too small
            if (required_size > ListVector::GetListCapacity(result)) {
                ListVector::Reserve(result, required_size);
            }

            // Fetch the child data pointer only after the potential reserve above
            auto &child = ListVector::GetEntry(result);
            auto child_data = FlatVector::GetData<string_t>(child);
            auto write_pos = list_offset;
            for (auto &entry : extracted) {
                child_data[write_pos++] = StringVector::AddStringOrBlob(child, entry.statement);
            }

            // Publish the new total number of child entries
            ListVector::SetListSize(result, required_size);

            return list_entry_t(list_offset, list_length);
        });
}

// Scalar num_statements: maps every input SQL string to the number of statements
// that ExtractStatementsFromSQL extracts from it, as a 64-bit integer.
static void NumStatementsScalarFunction(DataChunk &args, ExpressionState &state, Vector &result) {
    const auto count_statements = [](string_t sql_text) -> int64_t {
        // Extract the statements, then report how many there were
        std::vector<StatementResult> extracted;
        ExtractStatementsFromSQL(sql_text.GetString(), extracted);
        return static_cast<int64_t>(extracted.size());
    };

    UnaryExecutor::Execute<string_t, int64_t>(args.data[0], result, args.size(), count_statements);
}

// Extension scaffolding
// ---------------------------------------------------

void RegisterParseStatementsFunction(ExtensionLoader &loader) {
    // parse_statements(VARCHAR) as a table function: one output row per statement
    const vector<LogicalType> arguments = {LogicalType::VARCHAR};
    TableFunction statements_table_fn("parse_statements", arguments, ParseStatementsFunction,
                                      ParseStatementsBind, ParseStatementsInit);
    loader.RegisterFunction(statements_table_fn);
}

void RegisterParseStatementsScalarFunction(ExtensionLoader &loader) {
    // Scalar parse_statements(VARCHAR) -> LIST(VARCHAR): the statement strings
    const auto statement_list_type = LogicalType::LIST(LogicalType::VARCHAR);
    ScalarFunction list_fn("parse_statements", {LogicalType::VARCHAR}, statement_list_type,
                           ParseStatementsScalarFunction);
    loader.RegisterFunction(list_fn);

    // Scalar num_statements(VARCHAR) -> BIGINT: the number of statements
    ScalarFunction count_fn("num_statements", {LogicalType::VARCHAR}, LogicalType::BIGINT,
                            NumStatementsScalarFunction);
    loader.RegisterFunction(count_fn);
}

} // namespace duckdb
3 changes: 3 additions & 0 deletions src/parser_tools_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "parse_tables.hpp"
#include "parse_where.hpp"
#include "parse_functions.hpp"
#include "parse_statements.hpp"
#include "duckdb.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/string_util.hpp"
Expand All @@ -30,6 +31,8 @@ static void LoadInternal(ExtensionLoader &loader) {
RegisterParseWhereDetailedFunction(loader);
RegisterParseFunctionsFunction(loader);
RegisterParseFunctionScalarFunction(loader);
RegisterParseStatementsFunction(loader);
RegisterParseStatementsScalarFunction(loader);
}

void ParserToolsExtension::Load(ExtensionLoader &loader) {
Expand Down
79 changes: 79 additions & 0 deletions test/sql/parse_tools/scalar_functions/num_statements.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# name: test/sql/parse_tools/scalar_functions/num_statements.test
# description: test num_statements scalar function
# group: [num_statements]

# Before we load the extension, this will fail
statement error
SELECT num_statements('SELECT 42; SELECT 43;');
----
Catalog Error: Scalar Function with name num_statements does not exist!

# Require statement will ensure this test is run with this extension loaded
require parser_tools

# Single statement
query I
SELECT num_statements('SELECT 42;');
----
1

# Two statements
query I
SELECT num_statements('SELECT 42; SELECT 43;');
----
2

# Three statements
query I
SELECT num_statements('SELECT 1; SELECT 2; SELECT 3;');
----
3

# Multiple statements with different types
query I
SELECT num_statements('SELECT 1; INSERT INTO test VALUES (2); UPDATE test SET x = 1; SELECT 3;');
----
4

# Complex multi-statement query
query I
SELECT num_statements($$
WITH cte AS (SELECT 1 as a) SELECT * FROM cte;
SELECT upper('hello');
SELECT count(*) FROM users WHERE id > 10;
INSERT INTO log VALUES ('done');
$$);
----
4

# Single complex statement
query I
SELECT num_statements($$
WITH cte1 AS (SELECT * FROM table1),
cte2 AS (SELECT * FROM table2)
SELECT cte1.id, cte2.name
FROM cte1
JOIN cte2 ON cte1.id = cte2.id
WHERE cte1.active = true
ORDER BY cte1.created_at DESC;
$$);
----
1

# Empty input
query I
SELECT num_statements('');
----
0

# Whitespace only
query I
SELECT num_statements(' ');
----
0

# Invalid SQL should return 0
query I
SELECT num_statements('INVALID SQL SYNTAX HERE');
----
0
Loading
Loading