diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a6dfd25 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.lock binary \ No newline at end of file diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml new file mode 100644 index 0000000..0a6cac0 --- /dev/null +++ b/.github/workflows/run-tests.yml @@ -0,0 +1,17 @@ +name: Run tests + +on: pull_request + +jobs: + run-tests: + runs-on: ubuntu-latest + steps: + - name: Check out repository + uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: direnv-nix + uses: JRMurr/direnv-nix-action@v4.2.0 + - name: run tests + run: just test + shell: bash diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7d7bf09 --- /dev/null +++ b/.gitignore @@ -0,0 +1,232 @@ +# Created by https://www.toptal.com/developers/gitignore/api/git,python,linux,visualstudiocode +# Edit at https://www.toptal.com/developers/gitignore?templates=git,python,linux,visualstudiocode + +### Git ### +# Created by git for backups. 
To disable backups in Git: +# $ git config --global mergetool.keepBackup false +*.orig + +# Created by git when using merge tools for conflicts +*.BACKUP.* +*.BASE.* +*.LOCAL.* +*.REMOTE.* +*_BACKUP_*.txt +*_BASE_*.txt +*_LOCAL_*.txt +*_REMOTE_*.txt + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/git,python,linux,visualstudiocode + + +# Added by cargo + +/target + +.direnv diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..211d4b1 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,338 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "cc" +version = "1.2.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee0f8803222ba5a7e2777dd72ca451868909b1ac410621b676adf07280e9b5f" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "libc" +version = "0.2.175" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a82ae493e598baaea5209805c49bbf2ea7de956d50d7da0da1164f9c6d28543" + +[[package]] +name = "log" +version = "0.4.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "proc-macro2" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61789d7719defeb74ea5fe81f2fdfdbd28a803847077cecce2ff14e1472f6f1" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "psm" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e944464ec8536cd1beb0bbfd96987eb5e3b72f2ecdafdc5c769a37f1fa2ae1f" +dependencies = [ + "cc", +] + +[[package]] +name = "pyo3" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8970a78afe0628a3e3430376fc5fd76b6b45c4d43360ffd6cdd40bdde72b682a" +dependencies = [ + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"458eb0c55e7ece017adeba38f2248ff3ac615e53660d7c71a238d7d2a01c7598" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7114fe5457c61b276ab77c5055f206295b812608083644a5c5b2640c3102565c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8725c0a622b374d6cb051d11a0983786448f7785336139c3c94f5aa6bef7e50" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.25.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4109984c22491085343c05b0dbc54ddc405c3cf7b4374fc533f5c3313a572ccc" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "recursive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0786a43debb760f491b1bc0269fe5e84155353c67482b9e60d0cfb596054b43e" +dependencies = [ + "recursive-proc-macro-impl", + "stacker", +] + +[[package]] +name = "recursive-proc-macro-impl" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" +dependencies = [ + "quote", + "syn", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "sqlparser" +version = "0.58.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec4b661c54b1e4b603b37873a18c59920e4c51ea8ea2cf527d925424dbd4437c" +dependencies = [ + "log", + "recursive", +] + +[[package]] +name = "sqlquerypp" +version = "0.1.0" +dependencies = [ + "pyo3", + "sqlparser", + "thiserror", +] + +[[package]] +name = "stacker" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cddb07e32ddb770749da91081d8d0ac3a16f1a569a18b20348cd371f5dead06b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + +[[package]] +name = "thiserror" +version = "2.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d76d3f064b981389ecb4b6b7f45a0bf9fdac1d5b9204c7bd6714fecc302850" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "windows-sys" +version = 
"0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..a1702ad --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "sqlquerypp" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "sqlquerypp" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = "0.25.0" +sqlparser = "0.58.0" +thiserror = "2.0.15" diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 0000000..ec85a16 --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,82 @@ +# Notes for developers + +## Developing and testing `sqlquerypp` + +NixOS and direnv (the latter with flake and nix-command support enabled) must +be set up to develop with the specified dependencies within the nix flake +(rust compiler, python etc.). + +Following `justfile` commands are helpful for development: + +- `just develop`: compiles everything and installs the latest compiled + state of `sqlquerypp` into the current python virtual environment + which is located at `.venv/` at the repository root. Please note that + you might need to activate it manually using `source .venv/bin/activate`. + +- `just lint` checks whether all coding conventions (as defined in + `pyproject.toml` and `rustfmt.toml`) are fulfilled. + +- `just format` autoformats code according to coding conventions + as much as possible. + +- `just test` runs all lints and tests. + +## Architecture + +This package is mainly separated into two components: + +- High-level Python API: `python/sqlquerypp` + + The high level API architecture itself is very simple. 
It is recommended + to take a look at the test cases (`python/tests/`) or the documentation + in `sqlquerypp.compiler.Compiler` and its subclasses. + +- Rust API: `src/` + + - `lib.rs` is the main entrypoint to look at. It constructs a module with + the full-qualified name `sqlquerypp.sqlquerypp`. It is internal to the + Python API and exposes internally used, fast SQL preprocessor + implementations. Its python interface declaration is located in + `python/sqlquerypp/sqlquerypp.pyi`. + + - `error.rs`, `lex.rs`, `scanner.rs` and `types.rs` should be quite self- + explaining. + + - The code within `parser/` is responsible for parsing nodes (i.e. + representations of `sqlquerypp` directives) and generating codes + for them. + + - `ParserState` is a state automaton based parser implementation + which does the "magic" transforming `sqlquerypp` code strings + into internal data structures (in terms of compiler construction, + called "nodes" in abstract syntax tree, although `sqlquerypp` + does not provide a correct, academic-style AST-oriented implementation). + + - For example, while parsing `combined_result` instructions are + reflected as `CombinedResultNode` instances + (`src/parser/nodes/combined_result.rs`). These node objects + are obviously very low-level and stateful (many public and + optional fields). + + - When generating code, it's most recommended to use + `CompleteCombinedResultNode` objects. This strategy + applies to all nodes `sqlquerypp` supports. See also: + - `ParserState::finalize()` + - `FinalParserState` + + - `codegen/` provides common structs, traits and functions for + generating valid SQL statements from a `FinalParserState`. + +## Manual release workflow + +- `source .venv/bin/activate` + +- `maturin build --release` + + - if successful, returns output like "Built wheel for CPython 3.13 to 'PATH'" + +- `maturin upload ` (use 'PATH' from last command) + + - **NOTE**: This requires token-based authentication. 
As this is just a + quick-and-dirty solution which should not be necessary for long, I + won't document this further. diff --git a/README.md b/README.md new file mode 100644 index 0000000..526f4f1 --- /dev/null +++ b/README.md @@ -0,0 +1,63 @@ +# sqlquerypp: SQL query preprocessor + +`sqlquerypp` is a library for preprocessing SQL queries. Main purpose +is writing highly optimized queries with a simplified syntax, allowing +for both maintainability and high performance. + +## Limitations + +Currently, only MySQL 8.4 syntax is supported. + +## Why preprocessing SQL queries? + +SQL (Structured Query Language) follows a declarative paradigm, i.e. a query +explains "what should be done" not "how should it be done". This stands in +contrast to imperative programming, which expresses the "how should a +certain task be fulfilled" aspect. + +Database systems' internals are responsible for maintaining this aspect. +However, for certain large data structures, writing down "naive" +queries sometimes results in poor performance. + +## Supported performance optimizations + +### Combined `UNION` queries + +Consider the following original query: + + ``` + SELECT entity_b.* + FROM entity_b + INNER JOIN entity_a + ON entity_a.id = entity_b.entity_a_id + AND entity_a.criteria = 1337; + ``` + +This is a very simplified example, but if you assume `entity_b` contains very +many items, even correct index conditions may exhaust any DBMS' join buffer. 
+ +An alternative approach might be doing a loop at application side (Python +pseudocode), if network overhead is acceptable: + + ``` + all_matches_in_entity_b = [] + for entity_a_id in [rec.id + for rec in mysql_query("SELECT id FROM entity_a " + "WHERE criteria = 1337")]: + inner_result = mysql_query("SELECT * FROM entity_b " + f"WHERE entity_a_id = {entity_a_id}") + all_matches_in_entity_b += inner_result + ``` + +The following statement, being no valid SQL, translates to a MySQL +native construct of `Recursive Common Table Expression` and `UNION` +fragments when being compiled by `sqlquerypp`. This allows for maximal +query performance, because the inner query with reduced complexity +is still taken into account. At the same time, it grants minimal I/O +overhead as only one query is executed on the database: + + ``` + combined_result (SELECT id FROM entity_a WHERE criteria = 1337) AS $id { + SELECT * FROM entity_b WHERE entity_a_id = $id; + } + ``` \ No newline at end of file diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..88bf936 --- /dev/null +++ b/flake.lock @@ -0,0 +1,82 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1755274400, + "narHash": "sha256-rTInmnp/xYrfcMZyFMH3kc8oko5zYfxsowaLv1LVobY=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "ad7196ae55c295f53a7d1ec39e4a06d922f3b899", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-25.05", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + 
} + }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1755485198, + "narHash": "sha256-C3042ST2lUg0nh734gmuP4lRRIBitA6Maegg2/jYRM4=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "aa45e63d431b28802ca4490cfc796b9e31731df7", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..b76f303 --- /dev/null +++ b/flake.nix @@ -0,0 +1,37 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-25.05"; + flake-utils.url = "github:numtide/flake-utils"; + rust-overlay = { + url = "github:oxalica/rust-overlay"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = { self, nixpkgs, rust-overlay, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + overlays = [ (import rust-overlay) ]; + pkgs = import nixpkgs { + inherit system overlays; + }; + rust = pkgs.rust-bin.nightly.latest.default.override { + extensions = [ + "rust-src" # for rust-analyzer + ]; + }; + in with pkgs; { + devShell = mkShell { + buildInputs = [ + (pkgs.python313.withPackages (ps: [ + ps.pip + ps.ruff + ps.mypy + ])) + just + rust + ]; + }; + } + ); +} diff --git a/justfile b/justfile new file mode 100644 index 0000000..41f853c --- /dev/null +++ b/justfile @@ -0,0 +1,26 @@ +help: + just --list + +develop: + just initialize-venv + .venv/bin/python -m maturin develop + +initialize-venv: + python -m venv .venv + .venv/bin/pip install maturin + +lint: + ruff format --check --diff + mypy --check + cargo fmt --check + 
cargo clippy --all-targets --all-features -- --deny warnings + +format: + ruff format + cargo fmt + +test: + just develop + just lint + cargo test + python -m unittest discover python/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..77e282a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,57 @@ +[project] +name = "sqlquerypp" +version = "0.1.0a4" +description = "SQL query preprocessor for generating optimized queries" +readme = "README.md" +repository = "https://github.com/puzzleYOU/sqlquerypp" +requires-python = ">=3.9" +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python", + "Programming Language :: Rust", + "Typing :: Typed", +] + +[project.urls] +source = "https://github.com/puzzleYOU/sqlquerypp" + +[build-system] +requires = ["maturin>=1.9,<2.0"] +build-backend = "maturin" + +[tool.maturin] +python-source = "python" +features = ["pyo3/extension-module"] + +[tool.mypy] +files = "python/**/*.py" +python_version = "3.13" +strict = true + +[tool.ruff] +indent-width = 4 +line-length = 80 +src = ["python"] +target-version = "py313" + +[tool.ruff.lint] +ignore = [] + +[tool.ruff.format] +docstring-code-format = true +docstring-code-line-length = 60 +indent-style = "space" +line-ending = "lf" +quote-style = "double" diff --git a/python/sqlquerypp/__init__.py b/python/sqlquerypp/__init__.py new file mode 100644 index 0000000..a8502c3 --- /dev/null +++ b/python/sqlquerypp/__init__.py @@ -0,0 
+1,9 @@ +from .compiler import Compiler, MySQL84Compiler +from .types import Query + +# public API +__all__ = [ + "Compiler", + "MySQL84Compiler", + "Query", +] diff --git a/python/sqlquerypp/compiler.py b/python/sqlquerypp/compiler.py new file mode 100644 index 0000000..a6d2487 --- /dev/null +++ b/python/sqlquerypp/compiler.py @@ -0,0 +1,116 @@ +from abc import ABC, abstractmethod +import hashlib +import re +from typing import Any, Sequence + +from .sqlquerypp import ( + CompiledQueryDescriptor, + preprocess_mysql84_query, +) +from .types import Query + + +class Compiler(ABC): + @abstractmethod + def _compile_template(self, statement: str) -> CompiledQueryDescriptor: + pass + + def __init__(self, variable_placeholder: str = "?") -> None: + self._cache: dict[str, CompiledQueryDescriptor] = {} + self._variable_placeholder = variable_placeholder + + def compile(self, template: Query) -> Query: + """ + Compiles a given query to valid SQL. + """ + descriptor = self._resolve_compiled_descriptor(template.statement) + parameters = self._resolve_parameters_from_descriptor( + template, descriptor + ) + return Query(statement=descriptor.statement, parameters=parameters) + + def _resolve_parameters_from_descriptor( + self, + template: Query, + descriptor: CompiledQueryDescriptor, + ) -> Sequence[Any]: + final_parameters: list[Any] = [] + + last_statement_offset = 0 + last_parameters_offset = 0 + for slice in descriptor.combined_result_node_slices: + parameters_outside_combined_result = template.statement[ + last_statement_offset : slice.scope_begin + ].count(self._variable_placeholder) + parameters_within_combined_result = template.statement[ + slice.scope_begin : slice.scope_end + ].count(self._variable_placeholder) + + final_parameters += template.parameters[ + last_parameters_offset : last_parameters_offset + + parameters_outside_combined_result + ] + last_parameters_offset += parameters_outside_combined_result + + # compiler needs to duplicate params within 
combined_result nodes. + # that's why we append them twice, but each in order. + for _ in range(2): + final_parameters += template.parameters[ + last_parameters_offset : last_parameters_offset + + parameters_within_combined_result + ] + last_parameters_offset += parameters_within_combined_result + last_statement_offset = slice.scope_end + + if last_statement_offset < len(template.statement): + final_parameters += template.parameters[last_parameters_offset:] + + return final_parameters + + def _resolve_compiled_descriptor( + self, statement: str + ) -> CompiledQueryDescriptor: + key = self._build_cache_key(statement) + if key not in self._cache: + self._cache[key] = self._compile_template(statement) + return self._cache[key] + + def _build_cache_key(self, statement: str) -> str: + normalized = self._get_normalized_query_template_string(statement) + cache_key = hashlib.sha256(normalized).hexdigest() + return cache_key + + def _get_normalized_query_template_string(self, statement: str) -> bytes: + without_new_lines = self._strip_new_lines(statement) + cleaned = re.sub(r"[ ]+", " ", without_new_lines) + encoded = cleaned.encode() + return encoded + + def _strip_new_lines(self, query: str) -> str: + without_cr = query.replace("\r", " ") + without_lf = without_cr.replace("\n", " ") + return without_lf + + +class MySQL84Compiler(Compiler): + """ + An implementation compiling `sqlquerypp` specific syntax to valid MySQL 8.4 + queries. 
+ """ + + def _compile_template(self, statement: str) -> CompiledQueryDescriptor: + if self.pep_249_placeholders: + statement = statement.replace("%s", "?") + + result = preprocess_mysql84_query(statement) + + if self.pep_249_placeholders: + return CompiledQueryDescriptor( + statement=result.statement.replace("?", "%s"), + combined_result_node_slices=result.combined_result_node_slices, + ) + return result + + def __init__(self, pep_249_placeholders: bool = True) -> None: + self.pep_249_placeholders = pep_249_placeholders + super().__init__("%s" if self.pep_249_placeholders else "?") diff --git a/python/sqlquerypp/sqlquerypp.pyi b/python/sqlquerypp/sqlquerypp.pyi new file mode 100644 index 0000000..cac30af --- /dev/null +++ b/python/sqlquerypp/sqlquerypp.pyi @@ -0,0 +1,15 @@ +class CombinedResultNodeSlice: + scope_begin: int + scope_end: int + +class CompiledQueryDescriptor: + statement: str + combined_result_node_slices: list[CombinedResultNodeSlice] + + def __init__( + self, + statement: str, + combined_result_node_slices: list[CombinedResultNodeSlice], + ): ... + +def preprocess_mysql84_query(statement: str) -> CompiledQueryDescriptor: ... 
diff --git a/python/sqlquerypp/types.py b/python/sqlquerypp/types.py new file mode 100644 index 0000000..bec068e --- /dev/null +++ b/python/sqlquerypp/types.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass +from typing import Any, Sequence + + +@dataclass(frozen=True) +class Query: + statement: str + parameters: Sequence[Any] diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/tests/common.py b/python/tests/common.py new file mode 100644 index 0000000..a08efe9 --- /dev/null +++ b/python/tests/common.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod +import os +from textwrap import dedent +from typing import Any, Sequence +from unittest import TestCase + +from sqlquerypp import Compiler, Query + + +class CompilerTestCase(TestCase, ABC): + maxDiff = None + + @abstractmethod + def _get_compiler(self) -> Compiler: + pass + + def loadQueryFromFile( + self, test_module_name: str, test_function_name: str + ) -> str: + # tests.mysql84.test_something -> mysql84/test_something + module_subpath = "/".join(test_module_name.split(".")[1:]) + final_path = os.path.join( + os.path.dirname(__file__), + "expected_queries", + module_subpath, + f"{test_function_name}.sql", + ) + with open(final_path, "r") as fp: + return fp.read() + + def assertGeneratedQueryEqual( + self, + expected: Query, + template: Query, + ) -> None: + expected = self._normalize_query(expected) + actual_query = self._normalize_query( + self._get_compiler().compile(template) + ) + self._assert_for_equal_statements(expected, actual_query) + self.assertEqual(expected.parameters, actual_query.parameters) + + def _assert_for_equal_statements( + self, expected: Query, actual_query: Query + ) -> None: + msg = f"resulting query is: {actual_query.statement}" + self.assertEqual(expected.statement, actual_query.statement, msg) + + def _normalize_query(self, query: Query) -> Query: + return Query( + 
statement=dedent(query.statement.strip()), + parameters=query.parameters, + ) diff --git a/python/tests/expected_queries/mysql84/test_combined_result/test_with_multiple_parameters_and_union_fragments.sql b/python/tests/expected_queries/mysql84/test_combined_result/test_with_multiple_parameters_and_union_fragments.sql new file mode 100644 index 0000000..04929fc --- /dev/null +++ b/python/tests/expected_queries/mysql84/test_combined_result/test_with_multiple_parameters_and_union_fragments.sql @@ -0,0 +1,87 @@ +(WITH RECURSIVE all_entries (n, col_a1, col_a2, col_b1, col_b2) AS ( + WITH loop_values AS ( + SELECT + col_a1 + FROM + table_a + WHERE + criteria = %s + ) + SELECT + 0, + a.col_a1, + a.col_a2, + b.col_b1, + b.col_b2 + FROM + table_a AS a + LEFT JOIN table_b AS b ON b.col_a1 = a.col_a1 AND b.cond1 = %s AND b.cond2 = %s + WHERE + a.col_a1 = (SELECT * FROM loop_values LIMIT 1) + UNION ALL + SELECT + n + 1, + a.col_a1, + a.col_a2, + b.col_b1, + b.col_b2 + FROM + all_entries + LEFT JOIN table_a AS a ON a.col_a1 = (SELECT col_a1 FROM loop_values WHERE col_a1 > all_entries.col_a1 LIMIT 1) + LEFT JOIN table_b AS b ON b.col_a1 = a.col_a1 AND b.cond1 = %s AND b.cond2 = %s + WHERE + n + 1 < (SELECT COUNT(*) FROM loop_values) +) +SELECT + col_a1, + col_a2, + col_b1, + col_b2 +FROM + all_entries +WHERE + col_b1 IS NOT NULL AND col_b2 IS NOT NULL) +UNION ALL +(WITH RECURSIVE all_entries (n, col_a1, col_a2, col_b1, col_b2) AS ( + WITH loop_values AS ( + SELECT + col_a1 + FROM + table_a + WHERE + criteria = %s + ) + SELECT + 0, + a.col_a1, + a.col_a2, + b.col_b1, + b.col_b2 + FROM + table_a AS a + LEFT JOIN table_b AS b ON b.col_a1 = a.col_a1 AND b.cond3 = %s AND b.cond4 = %s + WHERE + a.col_a1 = (SELECT * FROM loop_values LIMIT 1) + UNION ALL + SELECT + n + 1, + a.col_a1, + a.col_a2, + b.col_b1, + b.col_b2 + FROM + all_entries + LEFT JOIN table_a AS a ON a.col_a1 = (SELECT col_a1 FROM loop_values WHERE col_a1 > all_entries.col_a1 LIMIT 1) + LEFT JOIN table_b AS b ON b.col_a1 
= a.col_a1 AND b.cond3 = %s AND b.cond4 = %s + WHERE + n + 1 < (SELECT COUNT(*) FROM loop_values) +) +SELECT + col_a1, + col_a2, + col_b1, + col_b2 +FROM + all_entries +WHERE + col_b1 IS NOT NULL AND col_b2 IS NOT NULL) \ No newline at end of file diff --git a/python/tests/mysql84/__init__.py b/python/tests/mysql84/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/tests/mysql84/test_combined_result.py b/python/tests/mysql84/test_combined_result.py new file mode 100644 index 0000000..87c8a31 --- /dev/null +++ b/python/tests/mysql84/test_combined_result.py @@ -0,0 +1,45 @@ +from sqlquerypp import Compiler, MySQL84Compiler, Query + +from ..common import CompilerTestCase + + +class CombinedResultTests(CompilerTestCase): + def _get_compiler(self) -> Compiler: + return MySQL84Compiler() + + def test_with_multiple_parameters_and_union_fragments(self) -> None: + template = Query( + """ + combined_result (SELECT col_a1 FROM table_a + WHERE criteria = %s) AS $id { + SELECT a.col_a1, a.col_a2, b.col_b1, b.col_b2 + FROM table_a a + INNER JOIN table_b b + ON b.col_a1 = a.col_a1 + AND b.cond1 = %s + AND b.cond2 = %s + WHERE a.col_a1 = $id + } + UNION ALL + combined_result (SELECT col_a1 FROM table_a + WHERE criteria = %s) AS $id { + SELECT a.col_a1, a.col_a2, b.col_b1, b.col_b2 + FROM table_a a + INNER JOIN table_b b + ON b.col_a1 = a.col_a1 + AND b.cond3 = %s + AND b.cond4 = %s + WHERE a.col_a1 = $id + } + """, + ["CRIT1", 1337, 42, "CRIT2", 31415, 1338], + ) + expected = Query( + self.loadQueryFromFile( + __name__, + "test_with_multiple_parameters_and_union_fragments", + ), + ["CRIT1", 1337, 42, 1337, 42, "CRIT2", 31415, 1338, 31415, 1338], + ) + + self.assertGeneratedQueryEqual(expected, template) diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..a31d40b --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,17 @@ +control_brace_style = "AlwaysNextLine" +error_on_line_overflow = true +format_code_in_doc_comments = true 
+format_macro_matchers = true +group_imports = "StdExternalCrate" +imports_granularity = "One" +imports_layout = "HorizontalVertical" +indent_style = "Visual" +match_block_trailing_comma = true +max_width = 80 +newline_style = "Unix" +normalize_comments = true +reorder_impl_items = true +spaces_around_ranges = true +struct_lit_single_line = false +use_field_init_shorthand = true +wrap_comments = true \ No newline at end of file diff --git a/src/codegen/common.rs b/src/codegen/common.rs new file mode 100644 index 0000000..5ee0e04 --- /dev/null +++ b/src/codegen/common.rs @@ -0,0 +1,26 @@ +use { + crate::error::QueryCompilerError, + sqlparser::{dialect::GenericDialect, parser::Parser}, +}; + +/// Reformats (i.e. indents and normalizes) a given SQL string to make +/// it more human-readable. +/// +/// This also ensures the query is valid SQL as far the `sqlparser` +/// crate can tell. In case the passed SQL string is invalid, an +/// according error is returned. +pub fn format_query_prettily(query: &str) + -> Result { + let parser = Parser::new(&GenericDialect {}); + let parsed = + parser.try_with_sql(query) + .map_err(|e| { + QueryCompilerError::ResultingQueryInvalid(query.into(), e) + })? 
+ .parse_query() + .map_err(|e| { + QueryCompilerError::ResultingQueryInvalid(query.into(), e) + })?; + + Ok(format!("{:#}", parsed)) +} diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs new file mode 100644 index 0000000..91e55f0 --- /dev/null +++ b/src/codegen/mod.rs @@ -0,0 +1,4 @@ +mod common; +pub mod mysql84; + +pub use mysql84::MySql84QueryCompiler; diff --git a/src/codegen/mysql84.rs b/src/codegen/mysql84.rs new file mode 100644 index 0000000..3c9c165 --- /dev/null +++ b/src/codegen/mysql84.rs @@ -0,0 +1,82 @@ +use { + crate::{ + codegen::common::format_query_prettily, + error::QueryCompilerError, + parser::{FinalParserState, Node}, + types::{CombinedResultNodeSlice, CompiledQueryDescriptor}, + }, + std::cmp::Ordering, +}; + +/// A trait supposed to be implemented upon `FinalParserState`. +pub trait MySql84QueryCompiler { + fn generate_code(&mut self) + -> Result; +} + +/// A trait supposed to be implemented upon any parsed node. +pub trait MySql84NodeCompiler { + fn generate_code(&self) -> Result; +} + +fn get_node_ordering_key(lhs: &impl Node, rhs: &impl Node) -> Ordering { + if lhs.get_end_position() > rhs.get_end_position() + { + Ordering::Less + } + else + { + Ordering::Greater + } +} + +fn process_nodes_in_order(state: &mut FinalParserState) + -> Result<(), QueryCompilerError> { + let mut nodes_in_order = get_all_nodes(state); + nodes_in_order.sort_by(get_node_ordering_key); + for node in nodes_in_order.iter() + { + let original = &state.statement + [node.get_begin_position() .. node.get_end_position() + 1]; + let generated_code = node.generate_code()?; + let replaced = + state.statement + .replace(original, format!("({generated_code:#})").as_str()); + state.statement = replaced; + } + + Ok(()) +} + +fn get_all_nodes(state: &mut FinalParserState) + -> Vec { + // NOTE it should be sufficient to just extend this function in + // case further nodes are being introduced. 
the remaining code + // should be sufficiently generic + state.combined_result_nodes.clone() +} + +impl MySql84QueryCompiler for FinalParserState { + fn generate_code(&mut self) + -> Result + { + process_nodes_in_order(self)?; + + let combined_result_node_slices = self.combined_result_nodes + .iter() + .map(|node| { + CombinedResultNodeSlice { + scope_begin: node.get_scope_begin_position(), + scope_end: node.get_end_position(), + } + }) + .collect(); + let descriptor = + CompiledQueryDescriptor { statement: + format_query_prettily(self.statement + .as_str())?, + combined_result_node_slices }; + + Ok(descriptor) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..09209b2 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,29 @@ +use { + pyo3::{exceptions::PyValueError, PyErr}, + sqlparser::parser::ParserError, + thiserror::Error, +}; + +#[derive(Clone, Debug, Error)] +pub enum QueryCompilerError { + #[error("expecting `{0}` after keyword `{1}`")] + MissingCharacter(char, &'static str), + + #[error("nesting `{0}` within `{1}` is not supported")] + UnsupportedNesting(&'static str, &'static str), + + #[error("directive `{0}` at offset `{1}` is incomplete")] + DirectiveIncomplete(&'static str, usize), + + #[error("parsing inner query failed: {0}")] + InnerQueryInvalid(String), + + #[error("resulting query is invalid: {0}, {1}")] + ResultingQueryInvalid(String, ParserError), +} + +impl From for PyErr { + fn from(value: QueryCompilerError) -> Self { + PyValueError::new_err(value.to_string()) + } +} diff --git a/src/lex.rs b/src/lex.rs new file mode 100644 index 0000000..718f36c --- /dev/null +++ b/src/lex.rs @@ -0,0 +1,13 @@ +//! Constants which are used while parsing. 
+ +pub const WORD_DELIMITER: &str = " "; + +pub const KEYWORD_COMBINED_RESULT: &str = "combined_result"; + +pub const VALID_KEYWORDS: [&str; 1] = [KEYWORD_COMBINED_RESULT]; + +pub const PARENTHESE_START: char = '('; +pub const PARENTHESE_END: char = ')'; +pub const BRACE_START: char = '{'; +pub const BRACE_END: char = '}'; +pub const VARIABLE_START: char = '$'; diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1ca8d67 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,57 @@ +use { + crate::{ + parser::ParserState, + types::{CombinedResultNodeSlice, CompiledQueryDescriptor}, + }, + pyo3::prelude::*, +}; + +mod codegen; +mod error; +mod lex; +mod parser; +mod scanner; +mod types; + +/// make_compiler_impl +/// +/// This is a shorthand for generating a high-level compiler function. +macro_rules! make_compiler_impl { + ($func_name:ident, $trait:ty) => { + #[pyfunction] + fn $func_name(statement: String) -> PyResult { + use $trait; + + // First, we construct the parser. See ParserState. + let mut parser = ParserState::initialize(&statement); + + // After that, we do all the lexical checks and parsing systematics. + // The parser now contains a + parser.parse()?; + + // When parsing, the parser deals with "intermediate structs" which + // means, those intermediates heavily make use of "std::Option". + // For the final code generation, it is not much helpful to always + // have to check whether the parsed objects are complete. + // This is what the separate state and the separate `Complete...` + // datastructs are for. See `FinalParserState`. + let mut finalized_state = parser.finalize()?; + + Ok(finalized_state.generate_code()?) + } + }; +} + +make_compiler_impl!(preprocess_mysql84_query, codegen::MySql84QueryCompiler); + +/// Constructs the (internal!) sqlquerypp module containing helper +/// datastructs and compiler implementations. 
+#[pymodule] +fn sqlquerypp(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(preprocess_mysql84_query, m)?)?; + + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs new file mode 100644 index 0000000..6636bbf --- /dev/null +++ b/src/parser/mod.rs @@ -0,0 +1,117 @@ +mod nodes; +mod state; +mod types; + +pub use { + nodes::Node, + state::{FinalParserState, ParserState}, +}; + +#[cfg(test)] +mod tests { + use crate::parser::{nodes::CompleteCombinedResultNode, ParserState}; + + fn get_combined_result_nodes(query: &str) + -> Vec { + let owned = query.to_string(); + let mut parser = ParserState::initialize(&owned); + parser.parse().unwrap(); + + let finalized = parser.finalize().unwrap(); + finalized.combined_result_nodes.clone() + } + + #[test] + fn no_nodes_from_empty_string() { + let nodes = get_combined_result_nodes(""); + assert_eq!(0, nodes.len()); + } + + #[test] + fn no_nodes_from_sql_query() { + let nodes = get_combined_result_nodes("SELECT * FROM somewhere;"); + assert_eq!(0, nodes.len()); + } + + #[test] + fn node_found() { + let query = " + SELECT * FROM + ( + combined_result (SELECT col_a1 FROM table_a) AS $id_a { + SELECT a.col_a1, a.col_a2, b.col_b1, b.col_b2 FROM table_a a + INNER JOIN table_b b + ON b.col_a1 = a.col_a1 AND b.cond1 = %s AND b.cond2 = %s + WHERE a.col_a1 = $id_a + } + ) + "; + let nodes = get_combined_result_nodes(query); + assert_eq!( + vec![ + CompleteCombinedResultNode::new( + 57, + 371, + "SELECT col_a1 FROM table_a".to_string(), + "$id_a".to_string(), + 111, + "SELECT a.col_a1, a.col_a2, b.col_b1, b.col_b2 FROM table_a a + INNER JOIN table_b b + ON b.col_a1 = a.col_a1 AND b.cond1 = %s AND b.cond2 = %s + WHERE a.col_a1 = $id_a".to_string(), + ), + ], + nodes, + ); + } + + #[test] + fn nodes_found() { + let query = " + SELECT * FROM + ( + combined_result (SELECT col_a1 FROM table_a) AS $id_a { + SELECT a.col_a1, a.col_a2, b.col_b1, b.col_b2 FROM table_a a + 
INNER JOIN table_b b + ON b.col_a1 = a.col_a1 AND b.cond1 = %s AND b.cond2 = %s + WHERE a.col_a1 = $id_a + } + UNION ALL + combined_result (SELECT col_z1 FROM table_z) AS $id_z { + SELECT z.col_z1, z.col_z2, b.col_b1, b.col_b2 FROM table_z z + INNER JOIN table_b b + ON b.col_z1 = z.col_z1 AND b.cond3 = ? AND b.cond4 = ? + WHERE z.col_z1 = $id_z + } + ) + "; + let nodes = get_combined_result_nodes(query); + assert_eq!( + vec![ + CompleteCombinedResultNode::new( + 57, + 371, + "SELECT col_a1 FROM table_a".to_string(), + "$id_a".to_string(), + 111, + "SELECT a.col_a1, a.col_a2, b.col_b1, b.col_b2 FROM table_a a + INNER JOIN table_b b + ON b.col_a1 = a.col_a1 AND b.cond1 = %s AND b.cond2 = %s + WHERE a.col_a1 = $id_a".to_string() + ), + CompleteCombinedResultNode::new( + 415, + 727, + "SELECT col_z1 FROM table_z".to_string(), + "$id_z".to_string(), + 469, + "SELECT z.col_z1, z.col_z2, b.col_b1, b.col_b2 FROM table_z z + INNER JOIN table_b b + ON b.col_z1 = z.col_z1 AND b.cond3 = ? AND b.cond4 = ? 
+ WHERE z.col_z1 = $id_z".to_string() + ), + ], + nodes, + ); + } +} diff --git a/src/parser/nodes/combined_result.rs b/src/parser/nodes/combined_result.rs new file mode 100644 index 0000000..6dc0ac5 --- /dev/null +++ b/src/parser/nodes/combined_result.rs @@ -0,0 +1,632 @@ +use { + crate::{ + codegen::mysql84::MySql84NodeCompiler, + error::QueryCompilerError, + lex::KEYWORD_COMBINED_RESULT, + parser::nodes::Node, + }, + sqlparser::{ + ast::{helpers::attached_token::AttachedToken, *}, + dialect::GenericDialect, + parser::Parser, + }, +}; + +#[derive(Clone, Debug)] +pub struct CombinedResultNode { + pub begin_position: usize, + pub end_position: Option, + pub iteration_query: Option, + pub iteration_item_variable: Option, + pub inner_query_begin: Option, + pub inner_query: Option, +} + +#[derive(Clone, Debug)] +pub struct CompleteCombinedResultNode { + begin_position: usize, + end_position: usize, + iteration_query: String, + iteration_item_variable: String, + inner_query_begin: usize, + inner_query: String, +} + +impl CompleteCombinedResultNode { + pub fn new(begin_position: usize, + end_position: usize, + iteration_query: String, + iteration_item_variable: String, + inner_query_begin: usize, + inner_query: String) + -> Self { + Self { begin_position, + end_position, + iteration_query, + iteration_item_variable, + inner_query_begin, + inner_query } + } +} + +impl Node for CompleteCombinedResultNode { + fn get_begin_position(&self) -> usize { + self.begin_position + } + + fn get_scope_begin_position(&self) -> usize { + self.inner_query_begin + } + + fn get_end_position(&self) -> usize { + self.end_position + } +} + +fn normalize_query(query: &str) -> String { + query.replace("\n", " ") + .replace("\r", " ") + .split(' ') + .filter(|&el| !el.is_empty()) + .collect::>() + .join(" ") +} + +impl PartialEq for CompleteCombinedResultNode { + fn eq(&self, other: &Self) -> bool { + self.begin_position == other.begin_position + && self.end_position == other.end_position + && 
self.iteration_query == other.iteration_query + && self.iteration_item_variable == other.iteration_item_variable + && self.inner_query_begin == other.inner_query_begin + && normalize_query(&self.inner_query) + == normalize_query(&other.inner_query) + } +} + +impl CombinedResultNode { + pub fn new(begin_position: usize) -> Self { + Self { begin_position, + end_position: None, + iteration_query: None, + iteration_item_variable: None, + inner_query_begin: None, + inner_query: None } + } +} + +impl TryFrom for CompleteCombinedResultNode { + type Error = QueryCompilerError; + + fn try_from(value: CombinedResultNode) -> Result { + if value.iteration_query.is_none() + || value.end_position.is_none() + || value.iteration_item_variable.is_none() + || value.inner_query_begin.is_none() + || value.inner_query.is_none() + { + let err = QueryCompilerError::DirectiveIncomplete( + KEYWORD_COMBINED_RESULT, + value.begin_position); + return Err(err); + } + + let node = + CompleteCombinedResultNode::new(value.begin_position, + value.end_position.unwrap(), + value.iteration_query.unwrap(), + value.iteration_item_variable + .unwrap(), + value.inner_query_begin.unwrap(), + value.inner_query.unwrap()); + Ok(node) + } +} + +impl From for QueryCompilerError { + fn from(value: sqlparser::parser::ParserError) -> Self { + Self::InnerQueryInvalid(value.to_string()) + } +} + +impl MySql84NodeCompiler for CompleteCombinedResultNode { + fn generate_code(&self) -> Result { + let original_select = + prepare_parser_with_query(&self.inner_query)?.parse_select()?; + + let final_select = compile_final_select(&original_select, self)?; + + Ok(final_select.to_string()) + } +} + +fn compile_final_select(original_select: &Select, + node: &CompleteCombinedResultNode) + -> Result, QueryCompilerError> { + let original_select_column_idents = + derive_original_select_columns(original_select); + + let cte_columns = + construct_recursive_cte_columns(&original_select_column_idents); + let cte_statement = + 
construct_recursive_cte_statement(original_select, cte_columns, node)?; + + let select = compile_recursive_cte(original_select, cte_statement)?; + Ok(select) +} + +fn compile_recursive_cte(original_select: &Select, + cte_statement: With) + -> Result, QueryCompilerError> { + let original_select_column_idents = + derive_original_select_columns(original_select); + + let where_fragments = + derive_joined_table_column_names(original_select) + .unwrap_or_default() + .iter() + .map(|identifier| format!("{identifier} IS NOT NULL")) + .collect::>() + .join(" AND "); + + let mut select = prepare_parser_with_query(format!( + "SELECT * FROM all_entries WHERE {where_fragments}" + ).as_str())?.parse_query()?; + select.with = Some(cte_statement); + let mut select_body = select.body + .as_select() + .expect("our own query is a valid SELECT") + .clone(); + select_body.projection = original_select_column_idents + .into_iter() + .map(|ident| SelectItem::UnnamedExpr(Expr::Identifier(ident))) + .collect(); + select.body = Box::new(SetExpr::Select(Box::new(select_body))); + Ok(select) +} + +fn derive_joined_table_column_names(original_select: &Select) + -> Option> { + let original_select_column_ident_pairs = + derive_fully_qualified_original_select_columns(original_select); + + if let TableFactor::Table { name, + alias, + .. } = &original_select.from[0].relation + { + let target_to_elide = if alias.is_some() + { + alias.clone().unwrap().to_string() + } + else + { + name.to_string() + }; + + return Some(original_select_column_ident_pairs.iter() + .filter(|&pair| { + pair.0.value + != target_to_elide + }) + .map(|pair| { + pair.1.value.clone() + }) + .collect()); + } + + None +} + +fn construct_cte_with_iteration(node: &CompleteCombinedResultNode) + -> Result { + let cte = + Cte { alias: TableAlias { name: "loop_values".into(), + columns: vec![] }, + query: + prepare_parser_with_query(&node.iteration_query)? 
+ .parse_query()?, + from: None, + materialized: None, + closing_paren_token: AttachedToken::empty() }; + let stmt = With { cte_tables: vec![cte], + recursive: false, + with_token: AttachedToken::empty() }; + Ok(stmt) +} + +fn construct_recursive_cte_statement(original_select: &Select, + cte_columns: Vec, + node: &CompleteCombinedResultNode) + -> Result { + let cte = + Cte { alias: TableAlias { name: "all_entries".into(), + columns: cte_columns }, + query: + Box::new(construct_recursive_cte_query(original_select, + node)?), + from: None, + materialized: None, + closing_paren_token: AttachedToken::empty() }; + let stmt = With { cte_tables: vec![cte], + recursive: true, + with_token: AttachedToken::empty() }; + Ok(stmt) +} + +fn construct_recursive_cte_query(original_select: &Select, + node: &CompleteCombinedResultNode) + -> Result { + Ok(Query { + body: Box::new( + SetExpr::SetOperation { + op: SetOperator::Union, + set_quantifier: SetQuantifier::All, + left: compile_cte_anchor(original_select, node)?, + right: compile_cte_loop(original_select, node)?, + } + ), + with: Some(construct_cte_with_iteration(node)?), + order_by: None, + limit_clause: None, + fetch: None, + locks: vec![], + for_clause: None, + settings: None, + format_clause: None, + pipe_operators: vec![], + }) +} + +fn construct_recursive_cte_columns(original_select_column_idents: &[Ident]) + -> Vec { + let mut cte_idents: Vec = original_select_column_idents.into(); + cte_idents.insert(0, Ident::new("n")); + + cte_idents.into_iter() + .map(|ident| TableAliasColumnDef { name: ident, + data_type: None }) + .collect() +} + +fn derive_original_select_columns(original_select: &Select) -> Vec { + original_select.projection + .iter() + .filter_map(convert_select_item_to_ident_option) + .collect() +} + +fn derive_fully_qualified_original_select_columns(original_select: &Select) + -> Vec<(Ident, Ident)> { + original_select.projection + .iter() + .filter_map( + convert_select_item_to_full_qualified_idents_option + 
) + .collect() +} + +fn convert_select_item_to_ident_option(item: &SelectItem) -> Option { + match item + { + SelectItem::UnnamedExpr(Expr::Identifier(ident)) => Some(ident.clone()), + SelectItem::UnnamedExpr(Expr::CompoundIdentifier(idents)) => + { + idents.last().cloned() + }, + _ => None, + } +} + +fn convert_select_item_to_full_qualified_idents_option( + item: &SelectItem) + -> Option<(Ident, Ident)> { + match item + { + SelectItem::UnnamedExpr(Expr::CompoundIdentifier(idents)) + if idents.len() == 2 => + { + Some((idents[0].clone(), idents[1].clone())) + }, + _ => None, + } +} + +fn compile_cte_anchor(original_select: &Select, + node: &CompleteCombinedResultNode) + -> Result, QueryCompilerError> { + let mut cte_anchor = original_select.clone(); + if cte_anchor.from.len() != 1 + { + let msg = "inner query may only have one table following + FROM directive"; + return Err(QueryCompilerError::InnerQueryInvalid(msg.into())); + } + + insert_anchor_iteration_index(&mut cte_anchor)?; + transform_all_joins_to_left_joins(&mut cte_anchor); + apply_anchor_iteration_variable(&mut cte_anchor, node)?; + + Ok(Box::new(SetExpr::Select(Box::new(cte_anchor)))) +} + +fn apply_anchor_iteration_variable(cte_anchor: &mut Select, + node: &CompleteCombinedResultNode) + -> Result<(), QueryCompilerError> { + if let Some(selection) = &cte_anchor.selection + { + let transformed_selection_fragment = + selection.to_string() + .replace(&node.iteration_item_variable, + "(SELECT * FROM loop_values LIMIT 1)"); + let transformed_selection = + prepare_parser_with_query(&transformed_selection_fragment)? 
+ .parse_expr()?; + cte_anchor.selection = Some(transformed_selection); + } + Ok(()) +} + +fn transform_all_joins_to_left_joins(select: &mut Select) { + select + .from[0] + .joins + .iter_mut() + .for_each(|join| { + if let Some(constraint) = derive_join_constraint(join) { + join.join_operator = JoinOperator::Left(constraint.clone()); + } + }); +} + +fn insert_anchor_iteration_index(cte_anchor: &mut Select) + -> Result<(), QueryCompilerError> { + cte_anchor.projection.insert( + 0, + SelectItem::UnnamedExpr( + prepare_parser_with_query("0")?.parse_expr()?, + ) + ); + Ok(()) +} + +fn derive_join_constraint(join: &Join) -> Option<&JoinConstraint> { + match &join.join_operator + { + JoinOperator::Left(constraint) => Some(constraint), + JoinOperator::Join(constraint) => Some(constraint), + JoinOperator::Inner(constraint) => Some(constraint), + JoinOperator::LeftOuter(constraint) => Some(constraint), + JoinOperator::Right(constraint) => Some(constraint), + JoinOperator::RightOuter(constraint) => Some(constraint), + JoinOperator::FullOuter(constraint) => Some(constraint), + JoinOperator::CrossJoin => None, + JoinOperator::Semi(constraint) => Some(constraint), + JoinOperator::LeftSemi(constraint) => Some(constraint), + JoinOperator::RightSemi(constraint) => Some(constraint), + JoinOperator::Anti(constraint) => Some(constraint), + JoinOperator::LeftAnti(constraint) => Some(constraint), + JoinOperator::RightAnti(constraint) => Some(constraint), + JoinOperator::CrossApply => None, + JoinOperator::OuterApply => None, + JoinOperator::AsOf { match_condition: _, + constraint, } => Some(constraint), + JoinOperator::StraightJoin(constraint) => Some(constraint), + } +} + +fn compile_cte_loop(original_select: &Select, + node: &CompleteCombinedResultNode) + -> Result, QueryCompilerError> { + let mut cte_loop = original_select.clone(); + if cte_loop.from.len() != 1 + { + let msg = "inner query may only have one table following + FROM directive"; + return 
Err(QueryCompilerError::InnerQueryInvalid(msg.into())); + } + insert_loop_iteration_index(&mut cte_loop)?; + transform_loop_table_name_to_cte_alias(&mut cte_loop); + transform_all_joins_to_left_joins(&mut cte_loop); + add_loop_join(&mut cte_loop, node)?; + finalize_selection(&mut cte_loop)?; + + Ok(Box::new(SetExpr::Select(Box::new(cte_loop)))) +} + +fn finalize_selection(cte_loop: &mut Select) -> Result<(), QueryCompilerError> { + let lhs = prepare_parser_with_query("n + 1")?.parse_expr()?; + + let subquery = + prepare_parser_with_query("SELECT COUNT(*) FROM loop_values")? + .parse_query()?; + + cte_loop.selection = + Some(Expr::BinaryOp { left: Box::new(lhs), + op: BinaryOperator::Lt, + right: Box::new(Expr::Subquery(subquery)) }); + + Ok(()) +} + +fn add_loop_join(cte_loop: &mut Select, + node: &CompleteCombinedResultNode) + -> Result<(), QueryCompilerError> { + let (loop_target_table_or_alias, loop_target_column) = + extract_table_and_column_for_iteration_variable(cte_loop, + &node.iteration_item_variable)?; + let loop_target_column_name = loop_target_column.clone().value; + + let inner_select_table_name = extract_iteration_query_table_name(node)?; + + let constraint = + construct_loop_join_constraint(&loop_target_table_or_alias, + loop_target_column, + loop_target_column_name)?; + + let table_factor = + construct_loop_join_table_factor(loop_target_table_or_alias, + inner_select_table_name)?; + + let join = Join { join_operator: JoinOperator::Left(constraint), + relation: table_factor, + global: false }; + cte_loop.from[0].joins.insert(0, join); + Ok(()) +} + +fn construct_loop_join_table_factor( + loop_target_table_or_alias: Ident, + inner_select_table_name: String) + -> Result { + let join_alias = loop_target_table_or_alias.value; + let table_factor = + prepare_parser_with_query( + format!("{inner_select_table_name} AS {join_alias}") + .as_str())? 
+ .parse_table_factor()?; + Ok(table_factor) +} + +fn construct_loop_join_constraint( + loop_target_table_or_alias: &Ident, + loop_target_column: Ident, + loop_target_column_name: String) + -> Result { + let join_subquery = + prepare_parser_with_query( + format!("SELECT {loop_target_column_name} FROM loop_values + WHERE {loop_target_column_name} + > all_entries.{loop_target_column_name} + LIMIT 1").as_str() + )?.parse_query()?; + let constraint = + JoinConstraint::On( + Expr::BinaryOp { + left: Box::new(Expr::CompoundIdentifier(vec![ + loop_target_table_or_alias.clone(), + loop_target_column, + ])), + op: BinaryOperator::Eq, + right: Box::new(Expr::Subquery(join_subquery)), + } + ); + Ok(constraint) +} + +fn extract_iteration_query_table_name(node: &CompleteCombinedResultNode) + -> Result { + let inner_select = + prepare_parser_with_query(&node.inner_query)?.parse_select()?; + if inner_select.from.len() != 1 + { + return Err(QueryCompilerError::InnerQueryInvalid( + "expected loop iteration query to have just one table + followed by FROM".into())); + } + let inner_select_table_name = match &inner_select.from[0].relation { + TableFactor::Table { name, ..} => Some(name.to_string()), + _ => None, + }.ok_or(QueryCompilerError::InnerQueryInvalid( + "could not derive table name from loop iteration query".into()))?; + Ok(inner_select_table_name) +} + +fn extract_table_and_column_for_iteration_variable( + cte_loop: &mut Select, + iteration_variable: &str) + -> Result<(Ident, Ident), QueryCompilerError> { + let idents = + extract_iteration_variable_idents(cte_loop, iteration_variable)?; + if idents.len() != 2 + { + return Err(QueryCompilerError::InnerQueryInvalid( + "expected lvalue of iteration variable to be + of schema: table.column".into())); + } + Ok((idents[0].clone(), idents[1].clone())) +} + +fn extract_iteration_variable_idents( + cte_loop: &mut Select, + iteration_variable: &str) + -> Result, QueryCompilerError> { + let err_candidate = 
QueryCompilerError::InnerQueryInvalid( + "should contain iteration variable".into()); + let selection = cte_loop.selection.clone().ok_or(err_candidate.clone())?; + let candidate = match &selection + { + Expr::BinaryOp { left: _, + op: _, + right: _, } => + { + get_iteration_target_identifier(&selection, + iteration_variable) + }, + _ => None, + }.ok_or(err_candidate.clone())?; + let idents = match candidate + { + Expr::CompoundIdentifier(idents) => Some(idents), + _ => None, + }.ok_or(err_candidate)?; + Ok(idents) +} + +fn get_iteration_target_identifier(expr: &Expr, + iteration_variable: &str) + -> Option { + match expr + { + Expr::BinaryOp { left, + op, + right, } => + { + if *op == BinaryOperator::Eq + { + if let Expr::Value(value_with_span) = &**right + { + if let Value::Placeholder(var) = &value_with_span.value + { + if var == iteration_variable + { + return Some(*left.clone()); + } + } + } + } + get_iteration_target_identifier(expr, iteration_variable) + }, + _ => None, + } +} + +fn transform_loop_table_name_to_cte_alias(cte_loop: &mut Select) { + cte_loop.from[0].relation = + TableFactor::Table { name: vec!["all_entries".into()].into(), + alias: None, + args: None, + with_hints: vec![], + version: None, + with_ordinality: false, + partitions: vec![], + json_path: None, + sample: None, + index_hints: vec![] }; +} + +fn insert_loop_iteration_index(cte_loop: &mut Select) + -> Result<(), QueryCompilerError> { + cte_loop.projection.insert( + 0, + SelectItem::UnnamedExpr( + prepare_parser_with_query("n + 1")?.parse_expr()?, + ) + ); + Ok(()) +} + +fn prepare_parser_with_query(query: &str) + -> Result, QueryCompilerError> { + let parser = sqlparser::parser::Parser::new(&GenericDialect {}); + Ok(parser.try_with_sql(query)?) 
+} diff --git a/src/parser/nodes/mod.rs b/src/parser/nodes/mod.rs new file mode 100644 index 0000000..c5ce168 --- /dev/null +++ b/src/parser/nodes/mod.rs @@ -0,0 +1,9 @@ +mod combined_result; + +pub use combined_result::{CombinedResultNode, CompleteCombinedResultNode}; + +pub trait Node { + fn get_begin_position(&self) -> usize; + fn get_scope_begin_position(&self) -> usize; + fn get_end_position(&self) -> usize; +} diff --git a/src/parser/state.rs b/src/parser/state.rs new file mode 100644 index 0000000..817c150 --- /dev/null +++ b/src/parser/state.rs @@ -0,0 +1,300 @@ +use crate::{ + error::*, + lex::*, + parser::{ + nodes::{CombinedResultNode, CompleteCombinedResultNode}, + types::NodesState, + }, + scanner::{get_mandatory_succeeding_character_position, TokenState}, +}; + +/// Reflects the current parser state. +/// +/// This makes heavy use of optionals and statefulness (i.e. is +/// heavily passed around as a mutable reference). It is quite +/// meant as a intermediate parser state automaton. +/// +/// For the codegen phase, it's not recommended to use this low-level +/// intermediate state automaton. See `ParserState::finalize()`. +pub struct ParserState<'t> { + statement: &'t String, + seen_token_state: Option, + combined_result_nodes_state: NodesState, + offset: usize, +} + +/// The final parser state. See `ParserState::finalize`. +pub struct FinalParserState { + pub statement: String, + pub combined_result_nodes: Vec, +} + +impl<'t> ParserState<'t> { + pub fn initialize(statement: &'t String) -> Self { + Self { statement, + seen_token_state: None, + combined_result_nodes_state: NodesState::new(), + offset: 0 } + } + + /// Steps through the given statement, word by word, and internally + /// updates the parser state accordingly (i.e. saves which parsed + /// objects have been seen and which data they contain). 
+ pub fn parse(&mut self) -> Result<(), QueryCompilerError> { + for word in self.statement.split(WORD_DELIMITER) + { + self.advance_word(word)?; + self.advance_offset(word); + } + Ok(()) + } + + /// Transforms the intermediate state automaton into a `FinalParserState`. + /// + /// It's recommended to use this transformation because: + /// - In contrast to the low-level state automaton it only makes use of + /// optionals where parsed objects' data model explicitly requires it. + /// - It checks whether the parsed objects are as complete as the codegen + /// phase requires (i.e. whether the parsed SQL code was incomplete or + /// otherwise semantically invalid). + pub fn finalize(&'t self) -> Result { + let final_state = FinalParserState { statement: self.statement + .clone(), + combined_result_nodes: + self.get_complete_nodes()? }; + Ok(final_state) + } + + fn get_complete_nodes( + &self) + -> Result, QueryCompilerError> { + let converted = + self.combined_result_nodes_state + .all_nodes + .iter() + .map(|n| n.clone().try_into()) + .collect::>>(); + + let ok_variants = + converted.iter() + .filter_map(|el| el.clone().ok()) + .collect::>(); + + let error_variants = converted.iter() + .filter_map(|el| el.clone().err()) + .collect::>(); + + if !error_variants.is_empty() + { + return Err(error_variants[0].clone()); + } + + Ok(ok_variants) + } + + fn advance_word(&mut self, word: &str) -> Result<(), QueryCompilerError> { + if let Some(next) = self.try_forward_to_keyword_seen_state(word)? + { + self.seen_token_state = Some(next); + return Ok(()); + } + + if let Some(next) = + self.try_forward_to_initiator_char_based_state(word)? 
+ { + self.seen_token_state = Some(next); + return Ok(()); + } + + Ok(()) + } + + fn advance_offset(&mut self, word: &str) { + self.offset += word.len() + WORD_DELIMITER.len(); + } + + fn try_forward_to_keyword_seen_state( + &mut self, + word: &str) + -> Result, QueryCompilerError> { + if VALID_KEYWORDS.contains(&word) + { + let current = + TokenState::from_keyword(word.to_string(), self.offset) + .expect("checked per .contains() above"); + self.handle_transition(¤t)?; + return Ok(Some(current)); + } + + Ok(None) + } + + fn try_forward_to_initiator_char_based_state( + &mut self, + word: &str) + -> Result, QueryCompilerError> { + if let Some(initiator) = word.chars().nth(0) + { + let state_candidate = + self.get_scanner_state_by_char_token(initiator); + + if let Some(state) = &state_candidate + { + self.handle_transition(state)?; + } + + return Ok(state_candidate); + } + + Ok(None) + } + + fn get_scanner_state_by_char_token(&self, + token: char) + -> Option { + match token + { + BRACE_START => Some(TokenState::OpeningBrace(self.offset)), + BRACE_END => Some(TokenState::ClosingBrace(self.offset)), + PARENTHESE_START => + { + Some(TokenState::OpeningParenthese(self.offset)) + }, + VARIABLE_START => Some(TokenState::Variable(self.offset)), + _ => None, + } + } + + fn handle_transition(&mut self, + current_token_state: &TokenState) + -> Result<(), QueryCompilerError> { + let handles_combined_result_node = + self.combined_result_nodes_state.current_node.is_some(); + + match (&self.seen_token_state, current_token_state) + { + (_, TokenState::CombinedResultsKeyword(offset)) => + { + self.handle_combined_results_keyword(offset)? + }, + + (_, TokenState::OpeningParenthese(offset)) + if handles_combined_result_node => + { + self.attach_iteration_query(*offset + 1)? + }, + + (_, TokenState::Variable(offset)) + if handles_combined_result_node => + { + self.attach_variable(*offset)? 
+ }, + + (_, TokenState::OpeningBrace(offset)) + if handles_combined_result_node => + { + self.mark_inner_query_begin(*offset)? + }, + + (_, TokenState::ClosingBrace(offset)) + if handles_combined_result_node => + { + self.finalize_combined_result_node(offset) + }, + + _ => + {}, + } + + Ok(()) + } + + fn handle_combined_results_keyword(&mut self, + offset: &usize) + -> Result<(), QueryCompilerError> { + if self.combined_result_nodes_state.current_node.is_some() + { + let err = + QueryCompilerError::UnsupportedNesting(KEYWORD_COMBINED_RESULT, + KEYWORD_COMBINED_RESULT); + return Err(err); + } + self.combined_result_nodes_state.current_node = + Some(CombinedResultNode::new(*offset)); + Ok(()) + } + + fn attach_iteration_query(&mut self, + cursor: usize) + -> Result<(), QueryCompilerError> { + if let Some(node) = &mut self.combined_result_nodes_state.current_node + { + let brace_start_pos = get_mandatory_succeeding_character_position( + cursor, + self.statement.len(), + self.statement, + BRACE_START, + KEYWORD_COMBINED_RESULT, + )?; + + let closing_brace_pos = + get_mandatory_succeeding_character_position( + cursor, + brace_start_pos, + self.statement, + PARENTHESE_END, + KEYWORD_COMBINED_RESULT, + )?; + + node.iteration_query = + Some(self.statement[cursor .. 
closing_brace_pos].into()); + } + Ok(()) + } + + fn attach_variable(&mut self, + cursor: usize) + -> Result<(), QueryCompilerError> { + let words_beyond_cursor = + self.statement[cursor ..].split(WORD_DELIMITER); + if let Some(found_variable) = words_beyond_cursor.into_iter().nth(0) + { + if let Some(node) = + &mut self.combined_result_nodes_state.current_node + { + node.iteration_item_variable = + Some(found_variable.trim().into()); + } + } + Ok(()) + } + + fn mark_inner_query_begin(&mut self, + cursor: usize) + -> Result<(), QueryCompilerError> { + if let Some(node) = &mut self.combined_result_nodes_state.current_node + { + node.inner_query_begin = Some(cursor); + } + Ok(()) + } + + fn finalize_combined_result_node(&mut self, offset: &usize) { + if let Some(node) = &mut self.combined_result_nodes_state.current_node + { + if let Some(begin) = node.inner_query_begin + { + let slice_start = begin + 1; + let slice_end = *offset - 1; + let slice = &self.statement[slice_start .. slice_end]; + node.inner_query = Some(slice.trim().into()); + } + node.end_position = Some(*offset); + self.combined_result_nodes_state + .all_nodes + .push(node.clone()); + self.combined_result_nodes_state.current_node = None; + } + } +} diff --git a/src/parser/types.rs b/src/parser/types.rs new file mode 100644 index 0000000..e840631 --- /dev/null +++ b/src/parser/types.rs @@ -0,0 +1,11 @@ +pub struct NodesState { + pub all_nodes: Vec, + pub current_node: Option, +} + +impl NodesState { + pub fn new() -> Self { + Self { all_nodes: vec![], + current_node: None } + } +} diff --git a/src/scanner.rs b/src/scanner.rs new file mode 100644 index 0000000..ac50147 --- /dev/null +++ b/src/scanner.rs @@ -0,0 +1,50 @@ +use crate::{error::QueryCompilerError, lex::*}; + +/// Reflects choices which token has been seen recently. 
+pub enum TokenState { + OpeningParenthese(usize), + OpeningBrace(usize), + ClosingBrace(usize), + CombinedResultsKeyword(usize), + Variable(usize), +} + +impl TokenState { + pub fn from_keyword(keyword: String, offset: usize) -> Option { + match keyword.as_str() + { + KEYWORD_COMBINED_RESULT => + { + Some(TokenState::CombinedResultsKeyword(offset)) + }, + _ => None, + } + } +} + +/// Returns the position of a required character. +/// +/// - `cursor` and `end` determine which substring of `statement` should be +/// scanned. +/// - `character`: the character whose position should be returned +/// - `keyword`: relevant for constructing the error message in case the +/// character has not been found. Primarily meant for constructing an error +/// message with semantics like "expected 'combined_result' (the keyword) +/// should have been closed with '}' (the character)". +/// +/// The returned offset is absolute to the entire statement, not just +/// the scanned slice. +pub fn get_mandatory_succeeding_character_position( + cursor: usize, + end: usize, + statement: &str, + character: char, + keyword: &'static str) + -> Result { + Ok(cursor + + statement[cursor..end] + .find(character) + .ok_or( + QueryCompilerError::MissingCharacter( + character, keyword))?) +} diff --git a/src/types.rs b/src/types.rs new file mode 100644 index 0000000..d4084ae --- /dev/null +++ b/src/types.rs @@ -0,0 +1,32 @@ +//! Datastructs for python bindings. +use pyo3::{pyclass, pymethods}; + +#[pyclass] +#[derive(Clone)] +pub struct CombinedResultNodeSlice { + #[pyo3(get)] + pub scope_begin: usize, + + #[pyo3(get)] + pub scope_end: usize, +} + +#[pyclass] +pub struct CompiledQueryDescriptor { + #[pyo3(get)] + pub statement: String, + + #[pyo3(get)] + pub combined_result_node_slices: Vec, +} + +#[pymethods] +impl CompiledQueryDescriptor { + #[new] + fn new(statement: String, + combined_result_node_slices: Vec) + -> Self { + Self { statement, + combined_result_node_slices } + } +}