diff --git a/.clangd b/.clangd new file mode 100644 index 00000000..7652669f --- /dev/null +++ b/.clangd @@ -0,0 +1,4 @@ +CompileFlags: + Add: + - "-std=c++23" + - "-D__cpp_concepts=202002L" diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..2eef0514 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,15 @@ +root = true + +[*.{md,txt}]] +indent_style = space +indent_size = 2 +insert_final_newline = true +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true + +[*.xml] +indent_size = 2 + +[.github/**/*.yml] +indent_size = 2 diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..ded2e36e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,16 @@ +# Set the default behavior, in case people don't have core.autocrlf set. +* text=auto + +# Explicitly declare text files you want to always be normalized and converted +# to native line endings on checkout. +*.cpp text eol=lf +*.hpp text eol=lf +*.txt text eol=lf + +# Declare files that will always have CRLF line endings on checkout. +*.sln text eol=crlf +*.vcxproj text eol=crlf + +# Denote all files that are truly binary and should not be modified. +*.png binary +*.jpg binary diff --git a/.github/prepare-test-run.sh b/.github/prepare-test-run.sh new file mode 100755 index 00000000..8125fcac --- /dev/null +++ b/.github/prepare-test-run.sh @@ -0,0 +1,209 @@ +#! /usr/bin/env bash + +# This script is used to prepare the test run environment on Github Actions. + +DBMS="$1" # One of: "SQLite3", "MS SQL Server 2019", "MS SQL Server 2022" "PostgreSQL", "Oracle", "MySQL" + +# Password to be set for the test suite with sufficient permissions (CREATE DATABASE, DROP DATABASE, ...) +DB_PASSWORD="BlahThat." +DB_NAME="LightweightTest" + +setup_sqlite3() { + echo "Setting up SQLite3..." + sudo apt install -y \ + libsqlite3-dev \ + libsqliteodbc \ + sqlite3 \ + unixodbc-dev + + if [ -n "$GITHUB_OUTPUT" ]; then + echo "Exporting ODBC_CONNECTION_INFO..." + # expose the ODBC connection string to connect to the database + echo "ODBC_CONNECTION_STRING=DRIVER=SQLite3;DATABASE=file::memory:" >> "${GITHUB_OUTPUT}" + fi +} + +setup_sqlserver() { + # References: + # https://learn.microsoft.com/en-us/sql/linux/sample-unattended-install-ubuntu?view=sql-server-ver16 + # https://learn.microsoft.com/en-us/sql/tools/sqlcmd/sqlcmd-utility + # https://learn.microsoft.com/en-us/sql/linux/quickstart-install-connect-docker + + set -x + local MSSQL_PID='evaluation' + local SS_VERSION="$1" + local UBUNTU_RELEASE="20.04" # we fixiate the version, because the latest isn't always available by MS -- "$(lsb_release -r | awk '{print $2}') + + echo "Installing sqlcmd..." + curl https://packages.microsoft.com/keys/microsoft.asc | sudo tee /etc/apt/trusted.gpg.d/microsoft.asc + sudo add-apt-repository "$(wget -qO- https://packages.microsoft.com/config/ubuntu/${UBUNTU_RELEASE}/prod.list)" + sudo apt update + sudo apt install sqlcmd + + echo "Installing ODBC..." + sudo ACCEPT_EULA=y DEBIAN_FRONTEND=noninteractive apt install -y unixodbc-dev unixodbc odbcinst mssql-tools18 + dpkg -L mssql-tools18 + + echo "ODBC drivers installed:" + sudo odbcinst -q -d + + echo "Querying ODBC driver for MS SQL Server..." + sudo odbcinst -q -d -n "ODBC Driver 18 for SQL Server" + + echo "Pulling SQL Server ${SS_VERSION} image..." + docker pull mcr.microsoft.com/mssql/server:${SS_VERSION}-latest + + echo "Starting SQL Server ${SS_VERSION}..." + docker run \ + -e "ACCEPT_EULA=Y" \ + -e "MSSQL_SA_PASSWORD=${DB_PASSWORD}" \ + -p 1433:1433 --name sql${SS_VERSION} --hostname sql${SS_VERSION} \ + -d \ + "mcr.microsoft.com/mssql/server:${SS_VERSION}-latest" + + docker ps -a + set +x + + echo "Wait for the SQL Server to start..." + counter=1 + errstatus=1 + while [ $counter -le 15 ] && [ $errstatus = 1 ] + do + echo "$counter..." + sleep 1s + sqlcmd \ + -S localhost \ + -U SA \ + -P "$DB_PASSWORD" \ + -Q "SELECT @@VERSION" 2>/dev/null + errstatus=$? + ((counter++)) + done + + # create a test database + sqlcmd -S localhost -U SA -P "${DB_PASSWORD}" -Q "CREATE DATABASE ${DB_NAME}" + + if [ -n "$GITHUB_OUTPUT" ]; then + echo "Exporting ODBC_CONNECTION_INFO..." + # expose the ODBC connection string to connect to the database server + echo "ODBC_CONNECTION_STRING=DRIVER={ODBC Driver 18 for SQL Server};SERVER=localhost;PORT=1433;UID=SA;PWD=${DB_PASSWORD};TrustServerCertificate=yes;DATABASE=${DB_NAME}" >> "${GITHUB_OUTPUT}" + fi +} + +setup_postgres() { + echo "Setting up PostgreSQL..." + # For Fedora: sudo dnf -y install postgresql-server postgresql-odbc + sudo apt install -y \ + postgresql \ + postgresql-contrib \ + libpq-dev \ + odbc-postgresql + + sudo postgresql-setup --initdb --unit postgresql + + # check Postgres, version, and ODBC installation + sudo systemctl start postgresql + psql -V + odbcinst -q -d + odbcinst -q -d -n "PostgreSQL ANSI" + odbcinst -q -d -n "PostgreSQL Unicode" + + echo "Wait for the PostgreSQL server to start..." + counter=1 + errstatus=1 + while [ $counter -le 15 ] && [ $errstatus = 1 ] + do + echo "$counter..." + pg_isready -U postgres + errstatus=$? + ((counter++)) + done + + echo "ALTER USER postgres WITH PASSWORD '$DB_PASSWORD';" > setpw.sql + sudo -u postgres psql -f setpw.sql + rm setpw.sql + + echo "Create database user..." + local DB_USER="$USER" + sudo -u postgres psql -c "CREATE USER $DB_USER WITH SUPERUSER PASSWORD '$DB_PASSWORD'" + + echo "Create database..." + sudo -u postgres createdb $DB_NAME + + if [ -n "$GITHUB_OUTPUT" ]; then + echo "Exporting ODBC_CONNECTION_INFO..." + echo "ODBC_CONNECTION_STRING=Driver={PostgreSQL ANSI};Server=localhost;Port=5432;Uid=$DB_USER;Pwd=$DB_PASSWORD;Database=$DB_NAME" >> "${GITHUB_OUTPUT}" + fi +} + +setup_oracle() { + echo "Setting up Oracle..." # TODO + + # References + # - https://github.com/gvenzl/oci-oracle-free + + # {{{ install Oracle SQL server on ubuntu + + local DB_PASSWORD="BlahThat." + local ORACLE_VERSION="$1" # e.g. "23.5", "23.2", ... + docker pull gvenzl/oracle-free:$ORACLE_VERSION + docker run -d -p 1521:1521 -e ORACLE_PASSWORD="$DB_PASSWORD" gvenzl/oracle-free:$ORACLE_VERSION + + # }}} + + # {{{ instant client + wget https://download.oracle.com/otn_software/linux/instantclient/213000/instantclient-basiclite-linux.x64-21.3.0.0.0.zip + wget https://download.oracle.com/otn_software/linux/instantclient/213000/instantclient-sqlplus-linux.x64-21.3.0.0.0.zip + wget https://download.oracle.com/otn_software/linux/instantclient/213000/instantclient-odbc-linux.x64-21.3.0.0.0.zip + unzip instantclient-basiclite-linux.x64-21.3.0.0.0.zip + unzip instantclient-sqlplus-linux.x64-21.3.0.0.0.zip + unzip instantclient-odbc-linux.x64-21.3.0.0.0.zip + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD/instantclient_21_3 + + cd instantclient_21_3 + mkdir etc + cp /etc/odbcinst.ini etc/. + cp ~/.odbc.ini etc/odbc.ini + ./odbc_update_ini.sh . + sudo cp etc/odbcinst.ini /etc/ + + odbcinst -q -d + odbcinst -q -d -n "Oracle 21 ODBC driver" + + # test connection (interactively) with: + ./sqlplus scott/tiger@'(DESCRIPTION = (ADDRESS = (PROTOCOL = TCP)(HOST = db)(PORT = 1521)) (CONNECT_DATA = (SERVER = DEDICATED) (SERVICE_NAME = orcl)))' + + # show version to console + + # }}} +} + +setup_mysql() { + # install mysql server and its odbc driver + sudo apt install -y mysql-server # TODO: odbc driver +} + +case "$DBMS" in + "SQLite3") + setup_sqlite3 + ;; + "MS SQL Server 2019") + setup_sqlserver 2019 + ;; + "MS SQL Server 2022") + setup_sqlserver 2022 + ;; + "PostgreSQL") + setup_postgres + ;; + "Oracle") + setup_oracle + ;; + "MySQL") + setup_mysql + ;; + *) + echo "Unknown DBMS: $DBMS" + exit 1 + ;; +esac diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000..90ad687b --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,220 @@ +name: Build + +on: + merge_group: + push: + paths-ignore: + - 'mkdocs.yml' + - '*.sh' + branches: + - master + pull_request: + paths-ignore: + - 'docs/**' + - 'LICENSE.txt' + - 'mkdocs.yml' + - '*.md' + branches: + - master + +concurrency: + group: build-${{ github.ref }} + cancel-in-progress: true + +env: + CTEST_OUTPUT_ON_FAILURE: 1 + SCCACHE_GHA_ENABLED: "true" + +jobs: + + # {{{ Common checks + check_clang_format: + name: "Check C++ style" + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - name: Install clang + run: | + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh 18 + sudo apt-get install clang-format-18 + - name: "Clang-format" + run: find ./src/ -name "*.cpp" -o -name "*.h" | xargs clang-format-18 --Werror --dry-run + + + check_clang_tidy: + name: "Check clang-tidy" + runs-on: ubuntu-24.04 + if: github.ref != 'refs/heads/master' + steps: + - uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: "ccache-ubuntu2404-clang-tidy" + max-size: 256M + - name: "update APT database" + run: sudo apt -q update + - name: Install clang + run: | + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh 19 + sudo apt -qy install clang-tidy + - name: "install dependencies" + run: sudo apt install -y cmake ninja-build catch2 unixodbc-dev sqlite3 libsqlite3-dev libsqliteodbc + - name: Install GCC + run: sudo apt install -y g++-14 + - name: "cmake" + run: | + cmake -S . -B build -G Ninja \ + -D CMAKE_CXX_COMPILER="g++-14" + - name: "build" + run: cmake --build build + - name: "run clang-tidy" + run: find ./src/ -name "*.cpp" -o -name "*.h" | xargs -n 1 -P $(nproc) clang-tidy -format-style=file -p build + + # }}} + # {{{ Windows + windows: + name: "Windows" + runs-on: windows-latest + steps: + - uses: actions/checkout@v4 + - name: Run sccache-cache + uses: mozilla-actions/sccache-action@v0.0.3 + - name: "vcpkg: Install dependencies" + uses: lukka/run-vcpkg@v11.1 + id: runvcpkg + with: + vcpkgDirectory: ${{ runner.workspace }}/vcpkg + vcpkgGitCommitId: 80403036a665cb8fcc1a1b3e17593d20b03b2489 + - name: "Generate build files" + run: cmake --preset windows-cl-release + env: + VCPKG_ROOT: "${{ runner.workspace }}/vcpkg" + - name: "Build" + run: cmake --build --preset windows-cl-release + - name: "Test" + if: false # TODO: Install sqliteodbc first + run: ctest --preset windows-cl-release + # }}} + # {{{ Ubuntu build CC matrix + ubuntu_build_cc_matrix: + strategy: + fail-fast: false + matrix: + cxx: [23] + build_type: ["RelWithDebInfo"] + compiler: + [ + "GCC 14", + # "Clang 18", (does not seem to have std::expected<> just yet) + ] + name: "Ubuntu Linux 24.04 (${{ matrix.compiler }}, C++${{ matrix.cxx }})" + runs-on: ubuntu-24.04 + outputs: + id: "${{ matrix.compiler }} (C++${{ matrix.cxx }}, ${{ matrix.build_type }})" + steps: + - uses: actions/checkout@v4 + - name: ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: "ccache-ubuntu2404-${{ matrix.compiler }}-${{ matrix.cxx }}-${{ matrix.build_type }}" + max-size: 256M + - name: "update APT database" + run: sudo apt -q update + - name: "Set up output var: CC_VERSION" + id: extract_matrix + run: | + CC_VERSION=$( echo "${{ matrix.compiler }}" | awk '{ print $2; }') + echo "CC_VERSION=${CC_VERSION}" >> "$GITHUB_OUTPUT" + - name: "install dependencies" + run: sudo apt install -y cmake ninja-build catch2 unixodbc-dev sqlite3 libsqlite3-dev libsqliteodbc + - name: "inspect" + run: | + dpkg -L unixodbc-common + dpkg -L unixodbc-dev + - name: Install GCC + if: ${{ startsWith(matrix.compiler, 'GCC') }} + run: sudo apt install -y g++-${{ steps.extract_matrix.outputs.CC_VERSION }} + - name: Install Clang + if: ${{ startsWith(matrix.compiler, 'Clang') }} + run: sudo apt install -y clang-${{ steps.extract_matrix.outputs.CC_VERSION }} #libc++-dev libc++abi-dev + - name: "cmake" + run: | + CC_NAME=$(echo "${{ matrix.compiler }}" | awk '{ print tolower($1); }') + CC_VER=$( echo "${{ matrix.compiler }}" | awk '{ print $2; }') + test "${{ matrix.compiler }}" = "GCC 8" && EXTRA_CMAKE_FLAGS="$EXTRA_CMAKE_FLAGS -DPEDANTIC_COMPILER_WERROR=ON" + test "${CC_NAME}" = "gcc" && CC_EXE="g++" + if [[ "${CC_NAME}" = "clang" ]]; then + CC_EXE="clang++" + # CMAKE_CXX_FLAGS="-stdlib=libc++" + # CMAKE_EXE_LINKER_FLAGS="-stdlib=libc++ -lc++abi" + # EXTRA_CMAKE_FLAGS="$EXTRA_CMAKE_FLAGS -DENABLE_TIDY=ON" + # EXTRA_CMAKE_FLAGS="$EXTRA_CMAKE_FLAGS -DPEDANTIC_COMPILER_WERROR=OFF" + fi + cmake \ + $EXTRA_CMAKE_FLAGS \ + -DCMAKE_BUILD_TYPE="${{ matrix.build_type }}" \ + -DCMAKE_CXX_STANDARD=${{ matrix.cxx }} \ + -DCMAKE_CXX_COMPILER="${CC_EXE}-${CC_VER}" \ + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS}" \ + -DCMAKE_EXE_LINKER_FLAGS="${CMAKE_EXE_LINKER_FLAGS}" \ + -DCMAKE_INSTALL_PREFIX="/usr" \ + -DPEDANTIC_COMPILER_WERROR=OFF \ + --preset linux-gcc-release + - name: "build" + run: cmake --build --preset linux-gcc-release -- -j3 + - name: "tests" + run: ctest --preset linux-gcc-release + - name: "Move tests to root directory" + run: | + mv out/build/linux-gcc-release/src/tests/LightweightTest . + - name: "Upload unit tests" + if: ${{ matrix.compiler == 'GCC 14' && matrix.cxx == '23' }} + uses: actions/upload-artifact@v4 + with: + name: ubuntu2404-tests + path: | + LightweightTest + retention-days: 1 + + dbms_test_matrix: + strategy: + fail-fast: false + matrix: + database: + [ + "SQLite3", + "MS SQL Server 2019", + "MS SQL Server 2022", + "PostgreSQL" + # TODO: "Oracle" + # TODO: "MySQL" or "MariaDB" + ] + name: "Tests (${{ matrix.database }})" + runs-on: ubuntu-24.04 + needs: [ubuntu_build_cc_matrix] + env: + DBMS: "${{ matrix.database }}" + steps: + - uses: actions/checkout@v4 + - name: "Download unit test binaries" + uses: actions/download-artifact@v4 + with: + name: ubuntu2404-tests + - name: "Mark unit tests as executable" + run: chmod 0755 LightweightTest + - name: "Setup ${{ matrix.database }}" + id: setup + run: bash ./.github/prepare-test-run.sh "${{ matrix.database }}" + - name: "Dump SQL connection string" + run: echo "ODBC_CONNECTION_STRING=${{ steps.setup.outputs.ODBC_CONNECTION_STRING }}" + - name: "Run SQL Core tests" + run: ./LightweightTest --trace-sql --trace-model -s # --odbc-connection-string="${{ steps.setup.outputs.ODBC_CONNECTION_STRING }}" + env: + ODBC_CONNECTION_STRING: "${{ steps.setup.outputs.ODBC_CONNECTION_STRING }}" + + # }}} diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..0a18858f --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.cache/ +.vs/ +.vscode/ +Testing/ +Debug/ +Release/ +build/ +vcpkg_installed/ +CMakeSettings.json +compile_commands.json diff --git a/.vimspector.json b/.vimspector.json new file mode 100644 index 00000000..182c06df --- /dev/null +++ b/.vimspector.json @@ -0,0 +1,27 @@ +{ + "$schema": "https://puremourning.github.io/vimspector/schema/vimspector.schema.json#", + "configurations": { + "ModelTest": { + "adapter": "vscode-cpptools", + "configuration": { + "request": "launch", + "program": "${workspaceRoot}/out/build/linux-gcc-debug/tools/ddl2cpp", + "args": [ + "--connection-string", "DRIVER=SQLite3;DATABASE=file::memory:", + "--create-test-tables", + "--output", "blurb.hpp" + ], + "cwd": "${workspaceRoot}", + "externalConsole": true, + "stopAtEntry": false, + "MIMode": "gdb" + }, + "breakpoints": { + "exception": { + "caught": "Y", + "uncaught": "Y" + } + } + } + } +} diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..845fd491 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,32 @@ +cmake_minimum_required(VERSION 3.16 FATAL_ERROR) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") + +project(Lightweight VERSION 0.1.0 LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 23) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_COLOR_DIAGNOSTICS ON) + +include(ClangTidy) +include(PedanticCompiler) + +if(NOT WIN32 AND NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) + set(CMAKE_BUILD_TYPE RelWithDebInfo CACHE STRING "Choose the build mode." FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS Debug Release MinSizeRel RelWithDebInfo) +endif() + +if(DEFINED MSVC) + add_definitions(-D_USE_MATH_DEFINES) + add_definitions(-DNOMINMAX) + add_compile_options(/utf-8) + add_compile_options(/nologo) +endif() + +add_subdirectory(src/Lightweight) +add_subdirectory(src/tools) + +enable_testing() +add_subdirectory(src/tests) diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 00000000..a87b5ea9 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,88 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "common", + "hidden": true, + "binaryDir": "${sourceDir}/out/build/${presetName}", + "installDir": "${sourceDir}/out/install/${presetName}", + "generator": "Ninja", + "cacheVariables": { + "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", + "CMAKE_VERBOSE_MAKEFILE": "OFF" + } + }, + { + "name": "windows-common", + "hidden": true, + "inherits": ["common"], + "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + }, + "cacheVariables": { + "CMAKE_SYSTEM_VERSION": "10.0" + } + }, + { + "name": "windows-cl-common", + "hidden": true, + "inherits": ["windows-common"], + "generator": "Visual Studio 17 2022", + "cacheVariables": { + "CMAKE_CXX_COMPILER": "cl" + } + }, + { "name": "windows-cl-debug", "inherits": ["windows-cl-common"], "displayName": "Windows - MSVC CL - Debug" }, + { "name": "windows-cl-release", "inherits": ["windows-cl-common"], "displayName": "Windows - MSVC CL - Release" }, + { + "name": "windows-clangcl-common", + "hidden": true, + "inherits": ["windows-common"], + "cacheVariables": { + "CMAKE_CXX_COMPILER": "clang-cl", + "CMAKE_C_COMPILER": "clang-cl" + } + }, + { "name": "windows-clangcl-debug", "inherits": ["windows-clangcl-common"], "displayName": "Windows - MSVC ClangCL - Debug", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "windows-clangcl-release","inherits": ["windows-clangcl-common"], "displayName": "Windows - MSVC ClangCL - Release", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "linux-common", "hidden": true, "inherits": ["common"], "condition": { "type": "equals", "lhs": "${hostSystemName}", "rhs": "Linux" } }, + { "name": "macos-common", "hidden": true, "inherits": ["common"], "condition": { "lhs": "${hostSystemName}", "type": "equals", "rhs": "Darwin" } }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "pedantic", "hidden": true, "cacheVariables": { "PEDANTIC_COMPILER": "ON", "PEDANTIC_COMPILER_WERROR": "ON" } }, + { "name": "linux-common-release", "hidden": true, "inherits": "release", "cacheVariables": { "CMAKE_INSTALL_PREFIX": "/usr/local" } }, + { "name": "linux-clang", "hidden": true, "inherits": "linux-common", "cacheVariables": { "CMAKE_CXX_COMPILER": "clang++" } }, + { "name": "linux-gcc", "hidden": true, "inherits": "linux-common", "cacheVariables": { "CMAKE_CXX_COMPILER": "g++" } }, + { "name": "linux-clang-release", "displayName": "Linux (Clang) Release", "inherits": ["linux-clang", "linux-common-release"] }, + { "name": "linux-clang-debug", "displayName": "Linux (Clang) Debug", "inherits": ["linux-clang", "debug"] }, + { "name": "linux-gcc-release", "displayName": "Linux (GCC) Release", "inherits": ["linux-gcc", "linux-common-release"] }, + { "name": "linux-gcc-debug", "displayName": "Linux (GCC) Debug", "inherits": ["linux-gcc", "debug", "pedantic"] }, + { "name": "macos-release", "displayName": "MacOS Release", "inherits": ["macos-common", "release"] }, + { "name": "macos-debug", "displayName": "MacOS Debug", "inherits": ["macos-common", "debug"] } + ], + "buildPresets": [ + { "name": "windows-cl-debug", "displayName": "Windows - MSVC CL - Debug", "configurePreset": "windows-cl-debug", "configuration": "Debug" }, + { "name": "windows-cl-release", "displayName": "Windows - MSVC CL - Release", "configurePreset": "windows-cl-release", "configuration": "Release" }, + { "name": "windows-clangcl-debug", "displayName": "Windows - MSVC ClangCL - Debug", "configurePreset": "windows-clangcl-debug", "configuration": "Debug" }, + { "name": "windows-clangcl-release", "displayName": "Windows - MSVC ClangCL - Release", "configurePreset": "windows-clangcl-release", "configuration": "Release" }, + { "name": "linux-clang-debug", "displayName": "Linux - Clang - Debug", "configurePreset": "linux-clang-debug" }, + { "name": "linux-clang-release", "displayName": "Linux - Clang - RelWithDebInfo", "configurePreset": "linux-clang-release" }, + { "name": "linux-gcc-debug", "displayName": "Linux - GCC - Debug", "configurePreset": "linux-gcc-debug" }, + { "name": "linux-gcc-release", "displayName": "Linux - GCC - RelWithDebInfo", "configurePreset": "linux-gcc-release" }, + { "name": "macos-debug", "displayName": "MacOS - Debug", "configurePreset": "macos-debug" }, + { "name": "macos-release", "displayName": "MacOS - RelWithDebInfo", "configurePreset": "macos-release" } + ], + "testPresets": [ + { "name": "windows-cl-debug", "configurePreset": "windows-cl-debug", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "windows-cl-release", "configurePreset": "windows-cl-release", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "windows-clangcl-debug", "configurePreset": "windows-clangcl-debug", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "windows-clangcl-release", "configurePreset": "windows-clangcl-release", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "linux-gcc-debug", "configurePreset": "linux-gcc-debug", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "linux-gcc-release", "configurePreset": "linux-gcc-release", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "linux-clang-debug", "configurePreset": "linux-clang-debug", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } }, + { "name": "linux-clang-release", "configurePreset": "linux-clang-release", "output": {"outputOnFailure": true}, "execution": { "noTestsAction": "error", "stopOnFailure": true } } + ] +} diff --git a/README.md b/README.md new file mode 100644 index 00000000..ca7ec316 --- /dev/null +++ b/README.md @@ -0,0 +1,55 @@ +# Lightweight, SQL and ORM data mapper API for C++23 + +**Lightweight** is a thin and modern C++ ODBC wrapper for **easy** and **fast** raw database access. + +## Features + +- **Easy to use** - simple, expressive and intuitive API +- **Performance** - do as little as possible, and as much as necessary - efficiency is key +- **Extensible** - support for custom data types for writing to and reading from columns +- **Exception safe** - no need to worry about resource management +- **Open Collaboration** - code directly integrated into the main project +- **Monad-like** - simple error handling with `std::expected`-like API + +## Non-Goals + +- Feature creeping (ODBC is a huge API, we are not going to wrap everything) +- No intend to support non-ODBC connectors directly, in order to keep the codebase simple and focused + +## C++ Language requirements + +This library a little bit of more modern C++ language and library features in order to be more expressive and efficient. + +- C++23 (`std::expected`, `std::stacktrace`, lambda templates expressions) +- C++20 (`std::source_location`, `std::error_code`, `operator<=>`, `std::format()`, designated initializers, concepts, ranges) +- C++17 (fold expressions, `std::string_view`, `std::optional`, `std::variant`, `std::apply`) + +## Supported Databases + +- Microsoft SQL +- PostgreSQL +- SQLite +- Oracle database (untested) + +## Using SQLite for testing on Windows operating system + +You need to have the SQLite3 ODBC driver for SQLite installed. + +- ODBC driver download URL: http://www.ch-werner.de/sqliteodbc/ +- Example connection string: `DRIVER={SQLite3 ODBC Driver};Database=file::memory:` + +### SQLite ODBC driver installation on other operating systems + +```sh +# Fedora Linux +sudo dnf install sqliteodbc + +# Ubuntu Linux +sudo apt install sqliteodbc + +# macOS +arch -arm64 brew install sqliteodbc +``` + +- sqliteODBC Documentation: http://www.ch-werner.de/sqliteodbc/html/index.html +- Example connection string: `DRIVER=SQLite3;Database=file::memory:` diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..e62ac71e --- /dev/null +++ b/TODO.md @@ -0,0 +1,5 @@ + +### TODO items + +- [ ] `SqlStatement::BindINputParameters(Params... params)` to unfold the pack with index properly incremented from 1... +- [x] check for more use of `ExecuteDirectScalar` diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake new file mode 100644 index 00000000..65c974b7 --- /dev/null +++ b/cmake/ClangTidy.cmake @@ -0,0 +1,15 @@ + +option(ENABLE_TIDY "Enable clang-tidy [default: OFF]" OFF) +if(ENABLE_TIDY) + find_program(CLANG_TIDY_EXE + NAMES clang-tidy + DOC "Path to clang-tidy executable") + if(NOT CLANG_TIDY_EXE) + message(STATUS "[clang-tidy] Not found.") + else() + message(STATUS "[clang-tidy] found: ${CLANG_TIDY_EXE}") + set(CMAKE_CXX_CLANG_TIDY "${CLANG_TIDY_EXE}") + endif() +else() + message(STATUS "[clang-tidy] Disabled.") +endif() diff --git a/cmake/PedanticCompiler.cmake b/cmake/PedanticCompiler.cmake new file mode 100644 index 00000000..b59c6bc1 --- /dev/null +++ b/cmake/PedanticCompiler.cmake @@ -0,0 +1,79 @@ +include(CheckCXXCompilerFlag) +function(try_add_compile_options FLAG) + # Remove leading - or / from the flag name. + string(REGEX REPLACE "^[-/]" "" name ${FLAG}) + # Deletes any ':' because it's invalid variable names. + string(REGEX REPLACE ":" "" name ${name}) + check_cxx_compiler_flag(${FLAG} ${name}) + if(${name}) + message(STATUS "Adding compiler flag: ${FLAG}.") + add_compile_options(${FLAG}) + else() + message(STATUS "Adding compiler flag: ${FLAG} failed.") + endif() + + # If the optional argument passed, store the result there. + if(ARGV1) + set(${ARGV1} ${name} PARENT_SCOPE) + endif() +endfunction() + +option(PEDANTIC_COMPILER "Compile the project with almost all warnings turned on." OFF) +option(PEDANTIC_COMPILER_WERROR "Enables -Werror to force warnings to be treated as errors." OFF) + +# Always show diagnostics in colored output. +try_add_compile_options(-fdiagnostics-color=always) + +if(${PEDANTIC_COMPILER}) + if(("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") OR ("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")) + message(STATUS "Enabling pedantic compiler options: yes") + # TODO: check https://github.com/lefticus/cppbestpractices/blob/master/02-Use_the_Tools_Available.md#compilers + try_add_compile_options(-Qunused-arguments) + try_add_compile_options(-Wall) + #try_add_compile_options(-Wconversion) + try_add_compile_options(-Wduplicate-enum) + try_add_compile_options(-Wduplicated-cond) + try_add_compile_options(-Wextra) + try_add_compile_options(-Wextra-semi) + try_add_compile_options(-Wfinal-dtor-non-final-class) + try_add_compile_options(-Wimplicit-fallthrough) + try_add_compile_options(-Wlogical-op) + try_add_compile_options(-Wmissing-declarations) + try_add_compile_options(-Wnewline-eof) + try_add_compile_options(-Wno-unknown-attributes) + try_add_compile_options(-Wno-unknown-pragmas) + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU") + # -Wdangling-reference will generate false positives on recent GCC versions. + # See https://gcc.gnu.org/git/gitweb.cgi?p=gcc.git;h=6b927b1297e66e26e62e722bf15c921dcbbd25b9 + try_add_compile_options(-Wno-dangling-reference) + else() + try_add_compile_options(-Wdangling-reference) + endif() + try_add_compile_options(-Wnull-dereference) + try_add_compile_options(-Wpessimizing-move) + try_add_compile_options(-Wredundant-move) + #try_add_compile_options(-Wsign-conversion) + try_add_compile_options(-Wsuggest-destructor-override) + try_add_compile_options(-pedantic) + else() + message(STATUS "Enabling pedantic compiler options: unsupported platform") + endif() +else() + message(STATUS "Enabling pedantic compiler options: no") +endif() + +if(${PEDANTIC_COMPILER_WERROR}) + try_add_compile_options(-Werror) + + # Don't complain here. That's needed for bitpacking (codepoint_properties) in libunicode dependency. + try_add_compile_options(-Wno-error=c++20-extensions) + try_add_compile_options(-Wno-c++20-extensions) + + # Not sure how to work around these. + try_add_compile_options(-Wno-error=class-memaccess) + try_add_compile_options(-Wno-class-memaccess) + + # TODO: Should be addressed. + try_add_compile_options(-Wno-error=missing-declarations) + try_add_compile_options(-Wno-missing-declarations) +endif() diff --git a/src/.clang-format b/src/.clang-format new file mode 100644 index 00000000..8ee01688 --- /dev/null +++ b/src/.clang-format @@ -0,0 +1,85 @@ +--- +BasedOnStyle: Microsoft +AccessModifierOffset: '-2' +AlignAfterOpenBracket: Align +AlignConsecutiveMacros: 'true' +AlignConsecutiveDeclarations: 'false' +AlignEscapedNewlines: Left +AlignOperands: 'true' +AlignTrailingComments: 'true' +AllowAllArgumentsOnNextLine: 'true' +AllowAllConstructorInitializersOnNextLine: 'true' +AllowAllParametersOfDeclarationOnNextLine: 'true' +AllowShortBlocksOnASingleLine: 'false' +AllowShortCaseLabelsOnASingleLine: 'false' +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: 'false' +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: 'false' +AlwaysBreakTemplateDeclarations: 'Yes' +BinPackArguments: 'false' +BinPackParameters: 'false' +BreakBeforeBinaryOperators: NonAssignment +BreakBeforeBraces: Custom +BreakBeforeTernaryOperators: 'true' +BreakConstructorInitializers: AfterColon +BreakInheritanceList: AfterColon +BreakStringLiterals: 'true' +ColumnLimit: '120' +CompactNamespaces: 'false' +PackConstructorInitializers: Never +ConstructorInitializerIndentWidth: '4' +ContinuationIndentWidth: '4' +Cpp11BracedListStyle: 'false' +DerivePointerAlignment: 'false' +FixNamespaceComments: 'true' +IncludeBlocks: Regroup +IndentCaseLabels: true +IndentPPDirectives: BeforeHash +IndentWidth: '4' +IndentWrappedFunctionNames: 'false' +Language: Cpp +MaxEmptyLinesToKeep: '1' +NamespaceIndentation: Inner +PenaltyBreakAssignment: '0' +PointerAlignment: Left +ReflowComments: 'true' +SortIncludes: 'true' +SortUsingDeclarations: 'true' +SpaceAfterCStyleCast: 'true' +SpaceAfterLogicalNot: 'false' +SpaceAfterTemplateKeyword: 'true' +SpaceBeforeAssignmentOperators: 'true' +SpaceBeforeCpp11BracedList: 'true' +SpaceBeforeCtorInitializerColon: 'false' +SpaceBeforeInheritanceColon: 'false' +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: 'false' +SpaceInEmptyParentheses: 'false' +SpacesInAngles: 'false' +SpacesInCStyleCastParentheses: 'false' +SpacesInContainerLiterals: 'false' +SpacesInParentheses: 'false' +SpacesInSquareBrackets: 'false' +Standard: Cpp11 +TabWidth: '4' +UseTab: Never +IncludeCategories: + - Regex: '^".*"' + Priority: 0 + - Regex: '^<(Lightweight)/' + Priority: 10 + - Regex: '^' + Priority: 81 + - Regex: '<[[:alnum:]_]+\.h>' + Priority: 82 + - Regex: '.*' + Priority: 99 diff --git a/src/.clang-tidy b/src/.clang-tidy new file mode 100644 index 00000000..956587b5 --- /dev/null +++ b/src/.clang-tidy @@ -0,0 +1,105 @@ +--- +Checks: >- + -*, + bugprone-*, + -bugprone-branch-clone, + -bugprone-easily-swappable-parameters, + -bugprone-exception-escape, + -bugprone-implicit-widening-of-multiplication-result, + -bugprone-reserved-identifier, + -bugprone-suspicious-include, + -bugprone-unchecked-optional-access, + clang-analyzer-core.*, + clang-analyzer-cplusplus.*, + clang-analyzer-deadcode.*, + clang-analyzer-nullability.*, + clang-analyzer-optin.cplusplus.*, + -clang-analyzer-optin.cplusplus.UninitializedObject, + clang-analyzer-optin.performance.*, + clang-analyzer-optin.portability.*, + clang-analyzer-security.*, + clang-analyzer-unix.*, + clang-diagnostic-*, + cppcoreguidelines-*, + cppcoreguidelines-*, + -cppcoreguidelines-avoid-c-arrays, + -cppcoreguidelines-avoid-const-or-ref-data-members, + -cppcoreguidelines-avoid-do-while, + -cppcoreguidelines-avoid-magic-numbers, + -cppcoreguidelines-avoid-non-const-global-variables, + -cppcoreguidelines-macro-usage, + -cppcoreguidelines-no-malloc, + -cppcoreguidelines-non-private-member-variables-in-classes, + -cppcoreguidelines-owning-memory, + -cppcoreguidelines-prefer-member-initializer, + -cppcoreguidelines-pro-bounds-array-to-pointer-decay, + -cppcoreguidelines-pro-bounds-constant-array-index, + -cppcoreguidelines-pro-bounds-pointer-arithmetic, + -cppcoreguidelines-pro-type-const-cast, + -cppcoreguidelines-pro-type-cstyle-cast, + -cppcoreguidelines-pro-type-reinterpret-cast, + -cppcoreguidelines-pro-type-static-cast-downcast, + -cppcoreguidelines-pro-type-union-access, + -cppcoreguidelines-pro-type-vararg, + -cppcoreguidelines-special-member-functions, + modernize-*, + -modernize-avoid-bind, + -modernize-avoid-c-arrays, + -modernize-return-braced-init-list, + -modernize-use-bool-literals, + -modernize-use-trailing-return-type, + performance-*, + -performance-no-int-to-ptr, + readability-*, + -readability-avoid-unconditional-preprocessor-if, + -readability-braces-around-statements, + -readability-container-contains, + -readability-else-after-return, + -readability-function-cognitive-complexity, + -readability-identifier-length, + -readability-implicit-bool-conversion, + -readability-magic-numbers, + -readability-named-parameter, + -readability-redundant-access-specifiers, + -readability-simplify-boolean-expr, + -readability-uppercase-literal-suffix, + -readability-use-anyofallof, + misc-const-correctness, +WarningsAsErrors: >- + modernize-use-nullptr, +UseColor: true +HeaderFilterRegex: 'src/' +FormatStyle: none +CheckOptions: + - key: bugprone-easily-swappable-parameters.MinimumLength + value: '3' + - key: cert-dcl16-c.NewSuffixes + value: 'L;LL;LU;LLU' + - key: cert-oop54-cpp.WarnOnlyIfThisHasSuspiciousField + value: '0' + - key: cppcoreguidelines-explicit-virtual-functions.IgnoreDestructors + value: '1' + - key: cppcoreguidelines-non-private-member-variables-in-classes.IgnoreClassesWithAllMemberVariablesBeingPublic + value: '1' + - key: google-readability-braces-around-statements.ShortStatementLines + value: '1' + - key: google-readability-function-size.StatementThreshold + value: '800' + - key: google-readability-namespace-comments.ShortNamespaceLines + value: '10' + - key: google-readability-namespace-comments.SpacesBeforeComments + value: '2' + - key: modernize-loop-convert.MaxCopySize + value: '16' + - key: modernize-loop-convert.MinConfidence + value: reasonable + - key: modernize-loop-convert.NamingStyle + value: CamelCase + - key: modernize-pass-by-value.IncludeStyle + value: llvm + - key: modernize-replace-auto-ptr.IncludeStyle + value: llvm + - key: modernize-use-nullptr.NullMacros + value: 'NULL' + - key: modernize-use-default-member-init.UseAssignment + value: '1' diff --git a/src/Lightweight/.editorconfig b/src/Lightweight/.editorconfig new file mode 100644 index 00000000..87cb9d48 --- /dev/null +++ b/src/Lightweight/.editorconfig @@ -0,0 +1,12 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +insert_final_newline = true +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true + +[*.md] +indent_size = 2 diff --git a/src/Lightweight/.gitignore b/src/Lightweight/.gitignore new file mode 100644 index 00000000..811b4571 --- /dev/null +++ b/src/Lightweight/.gitignore @@ -0,0 +1,2 @@ +Testing/ +build/ diff --git a/src/Lightweight/CMakeLists.txt b/src/Lightweight/CMakeLists.txt new file mode 100644 index 00000000..46ef5f0d --- /dev/null +++ b/src/Lightweight/CMakeLists.txt @@ -0,0 +1,85 @@ +cmake_minimum_required(VERSION 3.16 FATAL_ERROR) + +project(Lightweight LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 23) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS OFF) +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_COLOR_DIAGNOSTICS ON) + +if(NOT WIN32) + find_package(SQLite3 REQUIRED) + include_directories(${SQLITE3_INCLUDE_DIR}) + + # find unixODBC via pkg-config + find_package(PkgConfig REQUIRED) + pkg_check_modules(ODBC REQUIRED odbc) +endif() + +set(HEADER_FILES + Model/Associations/BelongsTo.hpp + Model/Associations/HasMany.hpp + Model/Associations/HasOne.hpp + Model/ColumnType.hpp + Model/Detail.hpp + Model/Logger.hpp + Model/Record.hpp + Model/Relation.hpp + Model/StringLiteral.hpp + Model/Utils.hpp + SqlComposedQuery.hpp + SqlConcepts.hpp + SqlConnectInfo.hpp + SqlConnection.hpp + SqlError.hpp + SqlLogger.hpp + SqlQueryFormatter.hpp + SqlSchema.hpp + SqlScopedTraceLogger.hpp + SqlStatement.hpp +) + +set(SOURCE_FILES + Model/AbstractRecord.cpp + Model/Logger.cpp + SqlComposedQuery.cpp + SqlConnectInfo.cpp + SqlConnection.cpp + SqlError.cpp + SqlLogger.cpp + SqlQueryFormatter.cpp + SqlSchema.cpp + SqlStatement.cpp + SqlTransaction.cpp +) + +add_library(Lightweight STATIC) +add_library(Lightweight::Lightweight ALIAS Lightweight) +target_compile_features(Lightweight PUBLIC cxx_std_23) +target_sources(Lightweight PRIVATE ${SOURCE_FILES}) +#target_sources(Lightweight PUBLIC ${HEADER_FILES}) + +if(CLANG_TIDY_EXE) + set_target_properties(Lightweight PROPERTIES CXX_CLANG_TIDY "${CLANG_TIDY_EXE}") +endif() + +# target_include_directories(Lightweight PUBLIC $) +target_include_directories(Lightweight PUBLIC + $ + $ +) + +if(MSVC) + target_compile_options(Lightweight PRIVATE /W4 /WX) + target_compile_options(Lightweight PRIVATE /MP) + target_link_libraries(Lightweight PUBLIC odbc32) +else() + target_compile_options(Lightweight PRIVATE -Wall -Wextra -pedantic -Werror) + target_compile_options(Lightweight PUBLIC ${ODBC_CFLAGS}) + target_link_libraries(Lightweight PUBLIC ${ODBC_LDFLAGS}) + + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + target_link_libraries(Lightweight PUBLIC stdc++exp) # GCC >= 14 + endif() +endif() diff --git a/src/Lightweight/Model/AbstractField.hpp b/src/Lightweight/Model/AbstractField.hpp new file mode 100644 index 00000000..d973b9de --- /dev/null +++ b/src/Lightweight/Model/AbstractField.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../SqlError.hpp" +#include "ColumnType.hpp" +#include "Detail.hpp" + +#include +#include + +class SqlStatement; + +namespace Model +{ +struct SqlColumnNameView +{ + std::string_view name; +}; +} // namespace Model + +template <> +struct std::formatter: std::formatter +{ + auto format(Model::SqlColumnNameView const& column, format_context& ctx) const -> format_context::iterator + { + return std::formatter::format(std::format("\"{}\"", column.name), ctx); + } +}; + +namespace Model +{ + +enum class FieldValueRequirement : uint8_t +{ + NULLABLE, + NOT_NULL, +}; + +constexpr inline FieldValueRequirement SqlNullable = FieldValueRequirement::NULLABLE; +constexpr inline FieldValueRequirement SqlNotNullable = FieldValueRequirement::NULLABLE; + +struct AbstractRecord; + +// Base class for all fields in a table row (Record). +class AbstractField +{ + public: + AbstractField(AbstractRecord& record, + SQLSMALLINT index, + std::string_view name, + SqlColumnType type, + FieldValueRequirement requirement): + m_record { &record }, + m_index { index }, + m_name { name }, + m_type { type }, + m_requirement { requirement } + { + } + + AbstractField() = delete; + AbstractField(AbstractField&&) = default; + AbstractField& operator=(AbstractField&&) = default; + AbstractField(AbstractField const&) = delete; + AbstractField& operator=(AbstractField const&) = delete; + virtual ~AbstractField() = default; + + // Returns the syntax for the SQL constraint specification for this field, if any, otherwise an empty string. + [[nodiscard]] virtual std::string SqlConstraintSpecifier() const + { + return ""; + } + + [[nodiscard]] virtual std::string InspectValue() const = 0; + virtual void BindInputParameter(SQLSMALLINT parameterIndex, SqlStatement& stmt) const = 0; + virtual void BindOutputColumn(SqlStatement& stmt) = 0; + virtual void BindOutputColumn(SQLSMALLINT outputIndex, SqlStatement& stmt) = 0; + virtual void LoadValueFrom(AbstractField& other) = 0; + + // clang-format off + AbstractRecord& GetRecord() noexcept { return *m_record; } + [[nodiscard]] AbstractRecord const& GetRecord() const noexcept { return *m_record; } + void SetRecord(AbstractRecord& record) noexcept { m_record = &record; } + [[nodiscard]] bool IsModified() const noexcept { return m_modified; } + void SetModified(bool value) noexcept { m_modified = value; } + [[nodiscard]] SQLSMALLINT Index() const noexcept { return m_index; } + [[nodiscard]] SqlColumnNameView Name() const noexcept { return m_name; } + [[nodiscard]] SqlColumnType Type() const noexcept { return m_type; } + [[nodiscard]] bool IsNullable() const noexcept { return m_requirement == FieldValueRequirement::NULLABLE; } + [[nodiscard]] bool IsRequired() const noexcept { return m_requirement == FieldValueRequirement::NOT_NULL; } + // clang-format on + + private: + AbstractRecord* m_record; + SQLSMALLINT m_index; + SqlColumnNameView m_name; + SqlColumnType m_type; + FieldValueRequirement m_requirement; + bool m_modified = false; +}; + +} // namespace Model diff --git a/src/Lightweight/Model/AbstractRecord.cpp b/src/Lightweight/Model/AbstractRecord.cpp new file mode 100644 index 00000000..616916ed --- /dev/null +++ b/src/Lightweight/Model/AbstractRecord.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +#include "AbstractField.hpp" +#include "AbstractRecord.hpp" + +#include +#include +#include +#include + +namespace Model +{ + +std::string AbstractRecord::Inspect() const noexcept +{ + if (!m_data) + return "UNAVAILABLE"; + + detail::StringBuilder result; + + // Reserve enough space for the output string (This is merely a guess, but it's better than nothing) + result.output.reserve(TableName().size() + AllFields().size() * 32); + + result << "#<" << TableName() << ": id=" << Id().value; + for (auto const* field: AllFields()) + result << ", " << field->Name() << "=" << field->InspectValue(); + result << ">"; + + return *result; +} + +void AbstractRecord::SetModified(bool value) noexcept +{ + for (AbstractField* field: m_data->fields) + field->SetModified(value); +} + +bool AbstractRecord::IsModified() const noexcept +{ + return std::ranges::any_of(m_data->fields, [](AbstractField* field) { return field->IsModified(); }); +} + +AbstractRecord::FieldList AbstractRecord::GetModifiedFields() const noexcept +{ + FieldList result; + std::ranges::copy_if(m_data->fields, std::back_inserter(result), [](auto* field) { return field->IsModified(); }); + return result; +} + +void AbstractRecord::SortFieldsByIndex() noexcept +{ + std::sort(m_data->fields.begin(), m_data->fields.end(), [](auto a, auto b) { return a->Index() < b->Index(); }); +} + +std::vector AbstractRecord::AllFieldNames() const +{ + std::vector columnNames; + columnNames.resize(1 + m_data->fields.size()); + columnNames[0] = PrimaryKeyName(); + for (auto const* field: AllFields()) + columnNames[field->Index() - 1] = field->Name().name; + return columnNames; +} + +} // namespace Model diff --git a/src/Lightweight/Model/AbstractRecord.hpp b/src/Lightweight/Model/AbstractRecord.hpp new file mode 100644 index 00000000..aea90d69 --- /dev/null +++ b/src/Lightweight/Model/AbstractRecord.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "RecordId.hpp" + +#include +#include +#include +#include + +namespace Model +{ + +class AbstractField; + +struct SqlColumnIndex +{ + size_t value; +}; + +// Base class for every SqlModel. +struct AbstractRecord +{ + public: + AbstractRecord(std::string_view tableName, std::string_view primaryKey, RecordId id): + m_data { std::make_unique(tableName, primaryKey, id) } + { + } + + AbstractRecord() = delete; + AbstractRecord(AbstractRecord const&) = delete; + AbstractRecord(AbstractRecord&& other) noexcept: + AbstractRecord { other.m_data->tableName, other.m_data->primaryKeyName, other.m_data->id } + { + } + + AbstractRecord& operator=(AbstractRecord const&) = delete; + AbstractRecord& operator=(AbstractRecord&&) = delete; + virtual ~AbstractRecord() = default; + + // Returns a human readable string representation of this model. + [[nodiscard]] std::string Inspect() const noexcept; + + // clang-format off + [[nodiscard]] std::string_view TableName() const noexcept { return m_data->tableName; } // TODO: make this statically accessible from Record<> as well + [[nodiscard]] std::string_view PrimaryKeyName() const noexcept { return m_data->primaryKeyName; } // TODO: make this statically accessible from Record<> as well + [[nodiscard]] RecordId Id() const noexcept { return m_data->id; } + + RecordId& MutableId() noexcept { return m_data->id; } + + void RegisterField(AbstractField& field) noexcept { m_data->fields.push_back(&field); } + + void UnregisterField(AbstractField const& field) noexcept + { + if (!m_data) + return; + // remove field by rotating it to the end and then popping it + auto it = std::ranges::find(m_data->fields, &field); + if (it != m_data->fields.end()) + { + std::rotate(it, std::next(it), m_data->fields.end()); + m_data->fields.pop_back(); + } + } + + [[nodiscard]] AbstractField const& GetField(SqlColumnIndex index) const noexcept { return *m_data->fields[index.value]; } + AbstractField& GetField(SqlColumnIndex index) noexcept { return *m_data->fields[index.value]; } + // clang-format on + + void SetModified(bool value) noexcept; + + [[nodiscard]] bool IsModified() const noexcept; + + void SortFieldsByIndex() noexcept; + + using FieldList = std::vector; + + [[nodiscard]] FieldList GetModifiedFields() const noexcept; + + [[nodiscard]] FieldList const& AllFields() const noexcept + { + return m_data->fields; + } + + [[nodiscard]] std::vector AllFieldNames() const; + + protected: + struct Data + { + std::string_view tableName; // Should be const, but we want to allow move semantics + std::string_view primaryKeyName; // Should be const, but we want to allow move semantics + RecordId id {}; + + bool modified = false; + FieldList fields; + }; + std::unique_ptr m_data; +}; + +} // namespace Model + +template + requires std::derived_from +struct std::formatter: std::formatter +{ + auto format(D const& record, format_context& ctx) const -> format_context::iterator + { + return formatter::format(record.Inspect(), ctx); + } +}; diff --git a/src/Lightweight/Model/All.hpp b/src/Lightweight/Model/All.hpp new file mode 100644 index 00000000..c99551d7 --- /dev/null +++ b/src/Lightweight/Model/All.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include "Associations/BelongsTo.hpp" +#include "Associations/HasMany.hpp" +#include "Associations/HasManyThrough.hpp" +#include "Associations/HasOne.hpp" +#include "Associations/HasOneThrough.hpp" +#include "Field.hpp" +#include "Logger.hpp" +#include "Record.hpp" +#include "RecordId.hpp" +#include "Utils.hpp" diff --git a/src/Lightweight/Model/Associations/BelongsTo.hpp b/src/Lightweight/Model/Associations/BelongsTo.hpp new file mode 100644 index 00000000..40857ac4 --- /dev/null +++ b/src/Lightweight/Model/Associations/BelongsTo.hpp @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../../SqlStatement.hpp" +#include "../AbstractField.hpp" +#include "../AbstractRecord.hpp" +#include "../ColumnType.hpp" +#include "../RecordId.hpp" +#include "../StringLiteral.hpp" + +#include +#include +#include + +namespace Model +{ + +template +struct Record; + +// Represents a column in a table that is a foreign key to another table. +template +class BelongsTo final: public AbstractField +{ + public: + constexpr static inline SQLSMALLINT ColumnIndex { TheColumnIndex }; + constexpr static inline std::string_view ColumnName { TheForeignKeyName.value }; + + explicit BelongsTo(AbstractRecord& record); + BelongsTo(BelongsTo const& other); + explicit BelongsTo(AbstractRecord& record, BelongsTo&& other); + BelongsTo& operator=(RecordId modelId); + BelongsTo& operator=(OtherRecord const& model); + ~BelongsTo() override = default; + + OtherRecord* operator->(); + OtherRecord& operator*(); + + [[nodiscard]] std::string SqlConstraintSpecifier() const override; + + [[nodiscard]] std::string InspectValue() const override; + void BindInputParameter(SQLSMALLINT parameterIndex, SqlStatement& stmt) const override; + void BindOutputColumn(SqlStatement& stmt) override; + void BindOutputColumn(SQLSMALLINT index, SqlStatement& stmt) override; + void LoadValueFrom(AbstractField& other) override; + + auto operator<=>(BelongsTo const& other) const noexcept; + + template + bool operator==(BelongsTo const& other) const noexcept; + + template + bool operator!=(BelongsTo const& other) const noexcept; + + void Load() noexcept; + + private: + void RequireLoaded(); + + RecordId m_value {}; + std::shared_ptr m_otherRecord; + + // We decided to use shared_ptr here, because we do not want to require to know the size of the OtherRecord + // at declaration time. +}; + +// {{{ BelongsTo<> implementation + +template +BelongsTo::BelongsTo(AbstractRecord& record): + AbstractField { + record, TheColumnIndex, TheForeignKeyName.value, ColumnTypeOf, TheRequirement, + } +{ + record.RegisterField(*this); +} + +template +BelongsTo::BelongsTo(BelongsTo const& other): + AbstractField { + const_cast(other).GetRecord(), + TheColumnIndex, + TheForeignKeyName.value, + ColumnTypeOf, + TheRequirement, + }, + m_value { other.m_value } +{ + GetRecord().RegisterField(*this); +} + +template +BelongsTo::BelongsTo(AbstractRecord& record, + BelongsTo&& other): + AbstractField { std::move(static_cast(other)) }, + m_value { std::move(other.m_value) } +{ + record.RegisterField(*this); +} + +template +BelongsTo& +BelongsTo::operator=(RecordId modelId) +{ + SetModified(true); + m_value = modelId; + return *this; +} + +template +BelongsTo& +BelongsTo::operator=(OtherRecord const& model) +{ + SetModified(true); + m_value = model.Id(); + return *this; +} + +template +inline OtherRecord* BelongsTo::operator->() +{ + RequireLoaded(); + return &*m_otherRecord; +} + +template +inline OtherRecord& BelongsTo::operator*() +{ + RequireLoaded(); + return *m_otherRecord; +} + +template +std::string BelongsTo::SqlConstraintSpecifier() const +{ + auto const otherRecord = OtherRecord {}; + // TODO: Move the syntax into SqlTraits, as a parametrized member function + return std::format("FOREIGN KEY ({}) REFERENCES {}({}) ON DELETE CASCADE", + ColumnName, + otherRecord.TableName(), + otherRecord.PrimaryKeyName()); +} + +template +inline std::string BelongsTo::InspectValue() const +{ + return std::to_string(m_value.value); +} + +template +inline void BelongsTo::BindInputParameter( + SQLSMALLINT parameterIndex, SqlStatement& stmt) const +{ + return stmt.BindInputParameter(parameterIndex, m_value.value); +} + +template +inline void BelongsTo::BindOutputColumn( + SqlStatement& stmt) +{ + return stmt.BindOutputColumn(TheColumnIndex, &m_value.value); +} + +template +inline void BelongsTo::BindOutputColumn( + SQLSMALLINT outputIndex, SqlStatement& stmt) +{ + return stmt.BindOutputColumn(outputIndex, &m_value.value); +} + +template +void BelongsTo::LoadValueFrom(AbstractField& other) +{ + assert(Type() == other.Type()); + m_value = std::move(static_cast(other).m_value); + m_otherRecord.reset(); +} + +template +inline auto BelongsTo::operator<=>( + BelongsTo const& other) const noexcept +{ + return m_value <=> other.m_value; +} + +template +template +inline bool BelongsTo::operator==( + BelongsTo const& other) const noexcept +{ + return m_value == other.m_value; +} + +template +template +inline bool BelongsTo::operator!=( + BelongsTo const& other) const noexcept +{ + return m_value == other.m_value; +} + +template +void BelongsTo::Load() noexcept +{ + if (m_otherRecord) + return; + + auto otherRecord = OtherRecord::Find(m_value); + if (otherRecord.has_value()) + m_otherRecord = std::make_shared(std::move(otherRecord.value())); +} + +template +inline void BelongsTo::RequireLoaded() +{ + if (!m_otherRecord) + { + Load(); + if (!m_otherRecord) + throw std::runtime_error("BelongsTo::RequireLoaded(): Record not found"); + } +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/Associations/HasMany.hpp b/src/Lightweight/Model/Associations/HasMany.hpp new file mode 100644 index 00000000..ae33bac4 --- /dev/null +++ b/src/Lightweight/Model/Associations/HasMany.hpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../../SqlError.hpp" +#include "../AbstractRecord.hpp" +#include "../Logger.hpp" +#include "../StringLiteral.hpp" + +#include + +namespace Model +{ + +template +class HasMany +{ + public: + explicit HasMany(AbstractRecord& parent); + HasMany(AbstractRecord& record, HasMany&& other) noexcept; + + [[nodiscard]] bool IsEmpty() const; + [[nodiscard]] size_t Count() const; + + std::vector& All(); + + OtherRecord& At(size_t index); + OtherRecord& operator[](size_t index); + + [[nodiscard]] bool IsLoaded() const noexcept; + void Load(); + void Reload(); + + private: + bool RequireLoaded(); + + bool m_loaded = false; + AbstractRecord* m_record; + std::vector m_models; +}; + +// {{{ HasMany<> implementation + +template +HasMany::HasMany(AbstractRecord& parent): + m_record { &parent } +{ +} + +template +HasMany::HasMany(AbstractRecord& record, HasMany&& other) noexcept: + m_loaded { other.m_loaded }, + m_record { &record }, + m_models { std::move(other.m_models) } +{ +} + +template +void HasMany::Load() +{ + if (m_loaded) + return; + + m_models = OtherRecord::Where(*ForeignKeyName, m_record->Id()).All(); + m_loaded = true; +} + +template +void HasMany::Reload() +{ + m_loaded = false; + m_models.clear(); + return Load(); +} + +template +bool HasMany::IsEmpty() const +{ + return Count() == 0; +} + +template +size_t HasMany::Count() const +{ + if (m_loaded) + return m_models.size(); + + SqlStatement stmt; + + auto const sqlQueryString = std::format( + "SELECT COUNT(*) FROM {} WHERE {} = {}", OtherRecord().TableName(), *ForeignKeyName, *m_record->Id()); + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + return stmt.ExecuteDirectScalar(sqlQueryString).value(); +} + +template +inline std::vector& HasMany::All() +{ + RequireLoaded(); + return m_models; +} + +template +inline OtherRecord& HasMany::At(size_t index) +{ + RequireLoaded(); + return m_models.at(index); +} + +template +inline OtherRecord& HasMany::operator[](size_t index) +{ + RequireLoaded(); + return m_models[index]; +} + +template +inline bool HasMany::IsLoaded() const noexcept +{ + return m_loaded; +} + +template +inline bool HasMany::RequireLoaded() +{ + if (!m_loaded) + Load(); + + return m_loaded; +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/Associations/HasManyThrough.hpp b/src/Lightweight/Model/Associations/HasManyThrough.hpp new file mode 100644 index 00000000..882a814d --- /dev/null +++ b/src/Lightweight/Model/Associations/HasManyThrough.hpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../../SqlComposedQuery.hpp" +#include "../AbstractRecord.hpp" +#include "../StringLiteral.hpp" + +namespace Model +{ + +template +class HasManyThrough +{ + public: + explicit HasManyThrough(AbstractRecord& record); + explicit HasManyThrough(AbstractRecord& record, HasManyThrough&& other) noexcept; + + [[nodiscard]] bool IsEmpty() const noexcept; + [[nodiscard]] size_t Count() const; + + std::vector& All(); + + template + void Each(Callback&& callback); + + TargetRecord& At(size_t index); + TargetRecord const& At(size_t index) const; + TargetRecord& operator[](size_t index); + TargetRecord const& operator[](size_t index) const; + + [[nodiscard]] bool IsLoaded() const noexcept; + void Load(); + void Reload(); + + private: + void RequireLoaded() const; + + AbstractRecord* m_record; + bool m_loaded = false; + std::vector m_models; +}; + +// {{{ inlines + +template +HasManyThrough::HasManyThrough(AbstractRecord& record): + m_record { &record } +{ +} + +template +HasManyThrough::HasManyThrough(AbstractRecord& record, + HasManyThrough&& other) noexcept: + m_record { &record }, + m_loaded { other.m_loaded }, + m_models { std::move(other.m_models) } +{ +} + +template +bool HasManyThrough::IsEmpty() const noexcept +{ + return Count() == 0; +} + +template +size_t HasManyThrough::Count() const +{ + if (IsLoaded()) + return m_models.size(); + + auto const targetRecord = TargetRecord(); + auto const throughRecordMeta = ThroughRecord(); // TODO: eliminate instances, allowing direct access to meta info + + return TargetRecord::Join(throughRecordMeta.TableName(), + LeftKeyName.value, + SqlQualifiedTableColumnName(targetRecord.TableName(), targetRecord.PrimaryKeyName())) + .Join(m_record->TableName(), + m_record->PrimaryKeyName(), + SqlQualifiedTableColumnName(throughRecordMeta.TableName(), RightKeyName.value)) + .Where(SqlQualifiedTableColumnName(m_record->TableName(), m_record->PrimaryKeyName()), m_record->Id()) + .Count(); +} + +template +inline std::vector& HasManyThrough::All() +{ + RequireLoaded(); + return m_models; +} + +template +TargetRecord& HasManyThrough::At(size_t index) +{ + RequireLoaded(); + return m_models.at(index); +} + +template +TargetRecord const& HasManyThrough::At(size_t index) const +{ + RequireLoaded(); + return m_models.at(index); +} + +template +TargetRecord& HasManyThrough::operator[](size_t index) +{ + RequireLoaded(); + return m_models[index]; +} + +template +TargetRecord const& HasManyThrough::operator[]( + size_t index) const +{ + RequireLoaded(); + return m_models[index]; +} + +template +template +inline void HasManyThrough::Each(Callback&& callback) +{ + if (IsLoaded()) + { + for (auto& model: m_models) + callback(model); + } + else + { + auto const targetRecord = TargetRecord(); + auto const throughRecordMeta = ThroughRecord(); + + TargetRecord::Join(throughRecordMeta.TableName(), + LeftKeyName.value, + SqlQualifiedTableColumnName(targetRecord.TableName(), targetRecord.PrimaryKeyName())) + .Join(m_record->TableName(), + m_record->PrimaryKeyName(), + SqlQualifiedTableColumnName(throughRecordMeta.TableName(), RightKeyName.value)) + .Where(SqlQualifiedTableColumnName(m_record->TableName(), m_record->PrimaryKeyName()), m_record->Id()) + .Each(callback); + } +} + +template +inline bool HasManyThrough::IsLoaded() const noexcept +{ + return m_loaded; +} + +template +void HasManyThrough::Load() +{ + if (m_loaded) + return; + + auto const targetRecord = TargetRecord(); + auto const throughRecordMeta = ThroughRecord(); + + m_models = + TargetRecord::Join(throughRecordMeta.TableName(), + LeftKeyName.value, + SqlQualifiedTableColumnName(targetRecord.TableName(), targetRecord.PrimaryKeyName())) + .Join(m_record->TableName(), + m_record->PrimaryKeyName(), + SqlQualifiedTableColumnName(throughRecordMeta.TableName(), RightKeyName.value)) + .Where(SqlQualifiedTableColumnName(m_record->TableName(), m_record->PrimaryKeyName()), m_record->Id()) + .All(); + + m_loaded = true; +} + +template +void HasManyThrough::Reload() +{ + m_loaded = false; + m_models.clear(); + Load(); +} + +template +void HasManyThrough::RequireLoaded() const +{ + if (!IsLoaded()) + const_cast(this)->Load(); +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/Associations/HasOne.hpp b/src/Lightweight/Model/Associations/HasOne.hpp new file mode 100644 index 00000000..17973d87 --- /dev/null +++ b/src/Lightweight/Model/Associations/HasOne.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../../SqlError.hpp" +#include "../AbstractRecord.hpp" +#include "../StringLiteral.hpp" + +#include + +namespace Model +{ + +// Represents a column in a another table that refers to this record. +template +class HasOne final +{ + public: + explicit HasOne(AbstractRecord& record); + explicit HasOne(AbstractRecord& record, HasOne&& other); + + OtherRecord& operator*(); + OtherRecord* operator->(); + [[nodiscard]] bool IsLoaded() const; + + bool Load(); + void Reload(); + + private: + void RequireLoaded(); + + AbstractRecord* m_record; + std::shared_ptr m_otherRecord; + + // We decided to use shared_ptr here, because we do not want to require to know the size of the OtherRecord + // at declaration time. +}; + +// {{{ HasOne<> implementation + +template +HasOne::HasOne(AbstractRecord& record): + m_record { &record } +{ +} + +template +HasOne::HasOne(AbstractRecord& record, HasOne&& other): + m_record { &record }, + m_otherRecord { std::move(other.m_otherRecord) } +{ +} + +template +OtherRecord& HasOne::operator*() +{ + RequireLoaded(); + return *m_otherRecord; +} + +template +OtherRecord* HasOne::operator->() +{ + RequireLoaded(); + return &*m_otherRecord; +} + +template +bool HasOne::IsLoaded() const +{ + return m_otherRecord.get() != nullptr; +} + +template +bool HasOne::Load() +{ + if (m_otherRecord) + return true; + + auto foundRecord = OtherRecord::FindBy(TheForeignKeyName.value, m_record->Id()); + if (!foundRecord) + return false; + + m_otherRecord = std::make_shared(std::move(foundRecord.value())); + return true; +} + +template +void HasOne::Reload() +{ + m_otherRecord.reset(); + Load(); +} + +template +void HasOne::RequireLoaded() +{ + if (!m_otherRecord) + Load(); +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/Associations/HasOneThrough.hpp b/src/Lightweight/Model/Associations/HasOneThrough.hpp new file mode 100644 index 00000000..a793b802 --- /dev/null +++ b/src/Lightweight/Model/Associations/HasOneThrough.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../../SqlComposedQuery.hpp" +#include "../AbstractRecord.hpp" +#include "../StringLiteral.hpp" + +#include + +namespace Model +{ + +template +class HasOneThrough +{ + public: + explicit HasOneThrough(AbstractRecord& record); + HasOneThrough(AbstractRecord& record, HasOneThrough&& other) noexcept; + + OtherRecord& operator*(); + OtherRecord* operator->(); + + [[nodiscard]] bool IsLoaded() const noexcept; + void Load(); + void Reload(); + + private: + AbstractRecord* m_record; + std::shared_ptr m_otherRecord; +}; + +// {{{ inlines + +template +HasOneThrough::HasOneThrough(AbstractRecord& record): + m_record { &record } +{ +} + +template +HasOneThrough::HasOneThrough(AbstractRecord& record, + HasOneThrough&& other) noexcept: + m_record { &record }, + m_otherRecord { std::move(other.m_otherRecord) } +{ +} + +template +OtherRecord& HasOneThrough::operator*() +{ + if (!m_otherRecord) + Load(); + + return *m_otherRecord; +} + +template +OtherRecord* HasOneThrough::operator->() +{ + if (!m_otherRecord) + Load(); + + return &*m_otherRecord; +} + +template +bool HasOneThrough::IsLoaded() const noexcept +{ + return m_otherRecord.get(); +} + +template +void HasOneThrough::Load() +{ + if (IsLoaded()) + return; + + auto result = + OtherRecord::template Join() + .Where(SqlQualifiedTableColumnName(OtherRecord().TableName(), ForeignKeyName.value), m_record->Id().value) + .First(); + + if (!result.has_value()) + { + SqlLogger::GetLogger().OnWarning(std::format("No data found on table {} for {} = {}", + OtherRecord().TableName(), + ForeignKeyName.value, + m_record->Id().value)); + return; + } + + m_otherRecord = std::make_shared(std::move(result.value())); +} + +template +void HasOneThrough::Reload() +{ + m_otherRecord.reset(); + return Load(); +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/ColumnType.hpp b/src/Lightweight/Model/ColumnType.hpp new file mode 100644 index 00000000..fb2c070e --- /dev/null +++ b/src/Lightweight/Model/ColumnType.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../SqlDataBinder.hpp" +#include "../SqlTraits.hpp" +#include "RecordId.hpp" + +#include +#include +#include + +namespace Model +{ + +namespace detail +{ + template + struct ColumnTypeOf; + + // clang-format off +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::CHAR; }; +template struct ColumnTypeOf> { static constexpr SqlColumnType value = SqlColumnType::STRING; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::STRING; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::STRING; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::TEXT; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::BOOLEAN; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::REAL; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::REAL; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::DATE; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::TIME; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::DATETIME; }; +template <> struct ColumnTypeOf { static constexpr SqlColumnType value = SqlColumnType::INTEGER; }; + // clang-format on +} // namespace detail + +template +constexpr SqlColumnType ColumnTypeOf = detail::ColumnTypeOf::value; + +} // namespace Model diff --git a/src/Lightweight/Model/Detail.hpp b/src/Lightweight/Model/Detail.hpp new file mode 100644 index 00000000..5a38d298 --- /dev/null +++ b/src/Lightweight/Model/Detail.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include + +namespace Model::detail +{ + +struct StringBuilder +{ + std::string output; + + std::string operator*() const& noexcept + { + return output; + } + + std::string operator*() && noexcept + { + return std::move(output); + } + + [[nodiscard]] bool empty() const noexcept + { + return output.empty(); + } + + template + StringBuilder& operator<<(T&& value) + { + if constexpr (std::is_same_v || std::is_same_v + || std::is_same_v) + output += value; + else + output += std::format("{}", std::forward(value)); + return *this; + } +}; + +} // namespace Model::detail diff --git a/src/Lightweight/Model/Field.hpp b/src/Lightweight/Model/Field.hpp new file mode 100644 index 00000000..000f0538 --- /dev/null +++ b/src/Lightweight/Model/Field.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../SqlStatement.hpp" +#include "AbstractField.hpp" +#include "AbstractRecord.hpp" +#include "ColumnType.hpp" +#include "StringLiteral.hpp" + +#include + +namespace Model +{ + +template +struct Record; + +// Represents a single column in a table. +// +// The column name, index, and type are known at compile time. +// If either name or index are not known at compile time, leave them at their default values, +// but at least one of them msut be known. +template +class Field: public AbstractField +{ + public: + explicit Field(AbstractRecord& record): + AbstractField { + record, TheTableColumnIndex, TheColumnName.value, ColumnTypeOf, TheRequirement, + } + { + record.RegisterField(*this); + } + + Field(Field const& other): + AbstractField { + const_cast(other).GetRecord(), + TheTableColumnIndex, + TheColumnName.value, + ColumnTypeOf, + TheRequirement, + }, + m_value { other.m_value } + { + GetRecord().RegisterField(*this); + } + + Field(AbstractRecord& record, Field&& other): + AbstractField { std::move(static_cast(other)) }, + m_value { std::move(other.m_value) } + { + record.RegisterField(*this); + } + + Field() = delete; + Field(Field&& other) = delete; + Field& operator=(Field&& other) = delete; + Field& operator=(Field const& other) = delete; + ~Field() override = default; + + // clang-format off + + template + auto operator<=>(Field const& other) const noexcept { return m_value <=> other.m_value; } + + // We also define the equality and inequality operators explicitly, because <=> from above does not seem to work in MSVC VS 2022. + template + auto operator==(Field const& other) const noexcept { return m_value == other.m_value; } + + template + auto operator!=(Field const& other) const noexcept { return m_value != other.m_value; } + + bool operator==(T const& other) const noexcept { return m_value == other; } + bool operator!=(T const& other) const noexcept { return m_value != other; } + + T const& Value() const noexcept { return m_value; } + void SetData(T&& value) { SetModified(true); m_value = std::move(value); } + void SetNull() { SetModified(true); m_value = T {}; } + + Field& operator=(T&& value) noexcept; + + T& operator*() noexcept { return m_value; } + T const& operator*() const noexcept { return m_value; } + + // clang-format on + + [[nodiscard]] std::string InspectValue() const override; + void BindInputParameter(SQLSMALLINT parameterIndex, SqlStatement& stmt) const override; + void BindOutputColumn(SqlStatement& stmt) override; + void BindOutputColumn(SQLSMALLINT index, SqlStatement& stmt) override; + + void LoadValueFrom(AbstractField& other) override + { + assert(Type() == other.Type()); + m_value = std::move(static_cast(other).m_value); + } + + private: + T m_value {}; +}; + +// {{{ Field<> implementation + +template +Field& Field::operator=(T&& value) noexcept +{ + SetModified(true); + m_value = std::move(value); + return *this; +} + +template +std::string Field::InspectValue() const +{ + if constexpr (std::is_same_v) + { + std::stringstream result; + result << std::quoted(m_value, '\''); + return result.str(); + } + else if constexpr (std::is_same_v) + { + std::stringstream result; + result << std::quoted(m_value.value, '\''); + return result.str(); + } + else if constexpr (std::is_same_v) + { + std::stringstream result; + result << std::quoted(m_value.value, '\''); + return result.str(); + } + else if constexpr (std::is_same_v) + return std::format("\'{}\'", m_value.value); + else if constexpr (std::is_same_v) + return std::format("\'{}\'", m_value.value); + else if constexpr (std::is_same_v) + return std::format("\'{}\'", m_value.value()); + else + return std::format("{}", m_value); +} + +template +void Field::BindInputParameter(SQLSMALLINT parameterIndex, + SqlStatement& stmt) const +{ + return stmt.BindInputParameter(parameterIndex, m_value); +} + +template +void Field::BindOutputColumn(SqlStatement& stmt) +{ + stmt.BindOutputColumn(TheTableColumnIndex, &m_value); +} + +template +void Field::BindOutputColumn(SQLSMALLINT outputIndex, + SqlStatement& stmt) +{ + stmt.BindOutputColumn(outputIndex, &m_value); +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/Logger.cpp b/src/Lightweight/Model/Logger.cpp new file mode 100644 index 00000000..af03a3bd --- /dev/null +++ b/src/Lightweight/Model/Logger.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +#include "Detail.hpp" +#include "Field.hpp" +#include "Logger.hpp" + +#include +#include +#include +#include + +namespace Model +{ + +class StandardQueryLogger: public QueryLogger +{ + private: + std::chrono::steady_clock::time_point m_startedAt; + std::string m_query; + std::vector m_output; + size_t m_rowCount {}; + + public: + void QueryStart(std::string_view query, std::vector const& output) override + { + m_startedAt = std::chrono::steady_clock::now(); + m_query = query; + m_output = output; + m_rowCount = 0; + } + + void QueryNextRow(AbstractRecord const& /*record*/) override + { + ++m_rowCount; + } + + void QueryEnd() override + { + auto const stoppedAt = std::chrono::steady_clock::now(); + auto const duration = std::chrono::duration_cast(stoppedAt - m_startedAt); + auto const seconds = std::chrono::duration_cast(duration); + auto const microseconds = std::chrono::duration_cast(duration - seconds); + auto const durationStr = std::format("{}.{:06}", seconds.count(), microseconds.count()); + + auto const rowCountStr = m_rowCount == 0 ? "" + : m_rowCount == 1 ? " [1 row]" + : std::format(" [{} rows]", m_rowCount); + + if (m_output.empty()) + { + std::println("[{}]{} {}", durationStr, rowCountStr, m_query); + return; + } + + detail::StringBuilder output; + + for (AbstractField const* field: m_output) + { + if (!output.empty()) + output << ", "; + output << field->Name().name << '=' << field->InspectValue(); + } + + std::println("[{}]{} {} WITH [{}]", durationStr, rowCountStr, m_query, *output); + } +}; + +static QueryLogger theNullLogger; + +QueryLogger* QueryLogger::NullLogger() noexcept +{ + return &theNullLogger; +} + +static StandardQueryLogger theStandardLogger; + +QueryLogger* QueryLogger::StandardLogger() noexcept +{ + return &theStandardLogger; +} + +QueryLogger* QueryLogger::m_instance = QueryLogger::NullLogger(); + +} // namespace Model diff --git a/src/Lightweight/Model/Logger.hpp b/src/Lightweight/Model/Logger.hpp new file mode 100644 index 00000000..4994eea4 --- /dev/null +++ b/src/Lightweight/Model/Logger.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include + +namespace Model +{ + +class AbstractField; +struct AbstractRecord; + +class QueryLogger +{ + public: + virtual ~QueryLogger() = default; + + using FieldList = std::vector; + + virtual void QueryStart(std::string_view /*query*/, FieldList const& /*output*/) {} + virtual void QueryNextRow(AbstractRecord const& /*model*/) {} + virtual void QueryEnd() {} + + static void Set(QueryLogger* next) noexcept + { + m_instance = next; + } + + static QueryLogger& Get() noexcept + { + return *m_instance; + } + + static QueryLogger* NullLogger() noexcept; + static QueryLogger* StandardLogger() noexcept; + + private: + static QueryLogger* m_instance; +}; + +namespace detail +{ + + struct SqlScopedModelQueryLogger + { + using FieldList = QueryLogger::FieldList; + + SqlScopedModelQueryLogger(std::string_view query, FieldList const& output) + { + QueryLogger::Get().QueryStart(query, output); + } + + SqlScopedModelQueryLogger& operator+=(AbstractRecord const& model) + { + QueryLogger::Get().QueryNextRow(model); + return *this; + } + + ~SqlScopedModelQueryLogger() + { + QueryLogger::Get().QueryEnd(); + } + }; + +} // namespace detail + +} // namespace Model diff --git a/src/Lightweight/Model/README.md b/src/Lightweight/Model/README.md new file mode 100644 index 00000000..5d1f344e --- /dev/null +++ b/src/Lightweight/Model/README.md @@ -0,0 +1,82 @@ +# Lightweight Model API + +The Lightweight Model API is a lightweight ORM that allows you to interact with your database using a simple and intuitive API. +It is designed to be easy to use while being as efficient as possible. + +This API was inspired by **Active Record** pattern and API from Ruby on Rails. + +## Features + +- **Simple & Intuitive API**: The API is designed to be as simple as possible. +- **Efficient**: The API is designed to be as efficient as possible. + +## Example + +```cpp +#include +#include + +struct Book; + +struct Author: Model::Record +{ + Model::Field name; + Model::HasMany books; +}; + +struct Book: Model::Record +{ + Model::Field title; + Model::Field isbn; + Model::BelongsTo author; +}; + +void demo() +{ + Model::CreateTables(); + + Author author; + author.name = "Bjarne Stroustrup"; + author.Save().or_else(std::abort); + + Book book; + book.title = "The C++ Programming Language"; + book.isbn = "978-0-321-56384-2"; + book.author = author; + book.Save().or_else(std::abort); + + auto books = Book::All().or_else(std::abort); + for (auto book: books) + std::println("{}", book); + + std::println("{} has {} books", author.name, author.books.Count()); + + author.Destroy(); + book.Destroy(); +} +``` + +## TODO: Open Refactors + +- [ ] drop SqlResult from Model API and use retry pattern instead, throw otherwise +- [ ] Consider reintroducing `Model::Record` +- [x] Support (join) Associations (https://guides.rubyonrails.org/association_basics.html) (e.g. `Author::All().Join().Where() == 42; v.As() == "42";`) + +## Open TODOs + +- [x] [Lightweight] Add custom type `SqlText` for `TEXT` fields +- [x] Add std::formatter for `Record` +- [x] Add test for `BelongsTo<>` (including eager & lazy-loading check) +- [x] Add test for `HasOne<>` (including eager & lazy-loading check) +- [x] Add test for `HasMany<>` (including eager & lazy-loading check) +- [x] Differenciate between VARCHAR (`std::string`) and TEXT (maybe `SqlText`?) +- [x] Make logging more useful, adding payload data +- [x] remove debug prints +- [x] add proper trace logging +- [x] Add `HasManyThrough<>` association +- [x] Add `HasOneThrough<>` association +- [ ] Add `HasAndBelongsToMany<>` association +- [ ] Add SQL query caching +- [x] Add lazy loading constraints (e.g. something similar to `Book::All().Where`) +- [ ] Add ability to configure PK's auto-increment to be server-side (default) vs client-side. this must be a compile-time option (via template parameter) diff --git a/src/Lightweight/Model/Record.hpp b/src/Lightweight/Model/Record.hpp new file mode 100644 index 00000000..6b282220 --- /dev/null +++ b/src/Lightweight/Model/Record.hpp @@ -0,0 +1,672 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../SqlComposedQuery.hpp" +#include "AbstractRecord.hpp" +#include "Detail.hpp" +#include "Field.hpp" +#include "Logger.hpp" + +#include +#include +#include +#include +#include +#include + +namespace Model +{ + +enum class SqlWhereOperator : uint8_t +{ + EQUAL, + NOT_EQUAL, + LESS_THAN, + LESS_OR_EQUAL, + GREATER_THAN, + GREATER_OR_EQUAL +}; + +constexpr std::string_view sqlOperatorString(SqlWhereOperator value) noexcept +{ + using namespace std::string_view_literals; + + auto constexpr mappings = std::array { + "="sv, "!="sv, "<"sv, "<="sv, ">"sv, ">="sv, + }; + + std::string_view result; + + if (std::to_underlying(value) < mappings.size()) + result = mappings[std::to_underlying(value)]; + + return result; +} + +// API to load records with more complex constraints. +// +// @see Record +// @see Record::Join() +template +class RecordQueryBuilder +{ + private: + explicit RecordQueryBuilder(SqlQueryBuilder queryBuilder): + m_queryBuilder { std::move(queryBuilder) } + { + } + + public: + explicit RecordQueryBuilder(): + m_queryBuilder { SqlQueryBuilder::From(TargetModel().TableName()) } + { + } + + template + [[nodiscard]] RecordQueryBuilder Join() && + { + auto const joinModel = JoinModel(); + (void) m_queryBuilder.InnerJoin(joinModel.TableName(), joinModel.PrimaryKeyName(), foreignKeyColumn.value); + return RecordQueryBuilder { std::move(m_queryBuilder) }; + } + + [[nodiscard]] RecordQueryBuilder Join(std::string_view joinTableName, + std::string_view joinTablePrimaryKey, + SqlQualifiedTableColumnName onComparisonColumn) && + { + (void) m_queryBuilder.InnerJoin(joinTableName, joinTablePrimaryKey, onComparisonColumn); + return RecordQueryBuilder { std::move(m_queryBuilder) }; + } + + [[nodiscard]] RecordQueryBuilder Join(std::string_view joinTableName, + std::string_view joinTablePrimaryKey, + std::string_view onComparisonColumn) && + { + (void) m_queryBuilder.InnerJoin(joinTableName, joinTablePrimaryKey, onComparisonColumn); + return RecordQueryBuilder { std::move(m_queryBuilder) }; + } + + template + [[nodiscard]] RecordQueryBuilder Where(ColumnName const& columnName, + SqlWhereOperator whereOperator, + T const& value) && + { + (void) m_queryBuilder.Where(columnName, sqlOperatorString(whereOperator), value); + return RecordQueryBuilder { std::move(m_queryBuilder) }; + } + + template + [[nodiscard]] RecordQueryBuilder Where(ColumnName const& columnName, T const& value) && + { + (void) m_queryBuilder.Where(columnName, value); + return *this; + } + + [[nodiscard]] RecordQueryBuilder OrderBy(std::string_view columnName, + SqlResultOrdering ordering = SqlResultOrdering::ASCENDING) && + { + (void) m_queryBuilder.OrderBy(columnName, ordering); + return *this; + } + + [[nodiscard]] std::size_t Count() + { + auto stmt = SqlStatement(); + auto const sqlQueryString = m_queryBuilder.Count().ToSql(stmt.Connection().QueryFormatter()); + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + return stmt.ExecuteDirectScalar(sqlQueryString).value(); + } + + [[nodiscard]] std::optional First(size_t count = 1) + { + TargetModel targetRecord; + + auto stmt = SqlStatement {}; + + auto const sqlQueryString = m_queryBuilder.Select(targetRecord.AllFieldNames(), targetRecord.TableName()) + .First(count) + .ToSql(stmt.Connection().QueryFormatter()); + + auto const _ = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + + stmt.Prepare(sqlQueryString); + stmt.Execute(); + + stmt.BindOutputColumn(1, &targetRecord.MutableId().value); + for (AbstractField* field: targetRecord.AllFields()) + field->BindOutputColumn(stmt); + + if (!stmt.FetchRow()) + return std::nullopt; + + return { std::move(targetRecord) }; + } + + [[nodiscard]] std::vector Range(std::size_t offset, std::size_t limit) + { + auto const targetRecord = TargetModel(); + auto const sqlQueryString = m_queryBuilder.Select(targetRecord.AllFieldNames(), targetRecord.TableName()) + .Range(offset, limit) + .ToSql(SqlConnection().QueryFormatter()); + return TargetModel::Query(sqlQueryString).value_or(std::vector {}); + } + + template + void Each(Callback&& callback) + { + auto const targetRecord = TargetModel(); + auto const sqlQueryString = m_queryBuilder.Select(targetRecord.AllFieldNames(), targetRecord.TableName()) + .All() + .ToSql(SqlConnection().QueryFormatter()); + TargetModel::Each(std::forward(callback), sqlQueryString); + } + + [[nodiscard]] std::vector All() + { + auto const targetRecord = TargetModel(); + auto const sqlQueryString = m_queryBuilder.Select(targetRecord.AllFieldNames(), targetRecord.TableName()) + .All() + .ToSql(SqlConnection().QueryFormatter()); + return TargetModel::Query(sqlQueryString); + } + + private: + SqlQueryBuilder m_queryBuilder; +}; + +template +struct Record: AbstractRecord +{ + public: + static constexpr auto Nullable = Model::SqlNullable; + static constexpr auto NotNullable = Model::SqlNotNullable; + + Record() = delete; + Record(Record const&) = default; + Record& operator=(Record const&) = delete; + Record& operator=(Record&&) = default; + ~Record() = default; + + Record(Record&& other) noexcept: + AbstractRecord(std::move(other)) + { + } + + // Creates (or recreates a copy of) the model in the database. + RecordId Create(); + + // Reads the model from the database by given model ID. + bool Load(RecordId id); + + // Re-reads the model from the database. + void Reload(); + + // Reads the model from the database by given column name and value. + template + bool Load(std::string_view const& columnName, T const& value); + + // Updates the model in the database. + void Update(); + + // Creates or updates the model in the database, depending on whether it already exists. + void Save(); + + // Deletes the model from the database. + void Destroy(); + + // Updates all models with the given changes in the modelChanges model. + static void UpdateAll(Derived const& modelChanges) noexcept; + + // Retrieves the first model from the database (ordered by ID ASC). + static std::optional First(size_t count = 1); + + // Retrieves the last model from the database (ordered by ID ASC). + static std::optional Last(); + + // Retrieves the model with the given ID from the database. + static std::optional Find(RecordId id); + + template + static std::optional FindBy(ColumnName const& columnName, T const& value); + + // Retrieves all models of this kind from the database. + static std::vector All() noexcept; + + // Retrieves the number of models of this kind from the database. + static size_t Count() noexcept; + + static RecordQueryBuilder Build(); + + // Joins the model with this record's model and returns a proxy for further joins and actions on this join. + template + static RecordQueryBuilder Join(); + + static RecordQueryBuilder Join(std::string_view joinTable, + std::string_view joinColumnName, + SqlQualifiedTableColumnName onComparisonColumn); + + template + static RecordQueryBuilder Where(std::string_view columnName, + SqlWhereOperator whereOperator, + Value const& value); + + template + static RecordQueryBuilder Where(std::string_view columnName, Value const& value); + + // Invokes a callback for each model that matches the given query string. + template + static void Each(Callback&& callback, std::string_view sqlQueryString, InputParameters&&... inputParameters); + + template + static std::vector Query(std::string_view sqlQueryString, InputParameters&&... inputParameters); + + // Returns the SQL string to create the table for this model. + static std::string CreateTableString(SqlServerType serverType); + + // Creates the table for this model from the database. + static void CreateTable(); + + // Drops the table for this model from the database. + static void DropTable(); + + protected: + explicit Record(std::string_view tableName, std::string_view primaryKey = "id"); +}; + +// {{{ Record<> implementation + +template +Record::Record(std::string_view tableName, std::string_view primaryKey): + AbstractRecord { tableName, primaryKey, RecordId {} } +{ +} + +template +size_t Record::Count() noexcept +{ + SqlStatement stmt; + return stmt.ExecuteDirectScalar(std::format("SELECT COUNT(*) FROM {}", Derived().TableName())).value(); +} + +template +std::optional Record::Find(RecordId id) +{ + static_assert(std::is_move_constructible_v, + "The model `Derived` must be move constructible for Find() to return the model."); + Derived model; + if (!model.Load(id)) + return std::nullopt; + return { std::move(model) }; +} + +template +template +std::optional Record::FindBy(ColumnName const& columnName, T const& value) +{ + static_assert(std::is_move_constructible_v, + "The model `Derived` must be move constructible for Find() to return the model."); + Derived model; + if (!model.Load(columnName, value)) + return std::nullopt; + return { std::move(model) }; +} + +template +std::vector Record::All() noexcept +{ + // Require that the model is copy constructible. Simply add a default move constructor to the model if it is not. + static_assert(std::is_move_constructible_v, + "The model `Derived` must be move constructible for All() to copy elements into the result."); + + std::vector allModels; + + Derived const modelSchema; + + detail::StringBuilder sqlColumnsString; + sqlColumnsString << modelSchema.PrimaryKeyName(); + for (AbstractField const* field: modelSchema.AllFields()) + sqlColumnsString << ", " << field->Name(); + + SqlStatement stmt; + + auto const sqlQueryString = std::format("SELECT {} FROM {}", *sqlColumnsString, modelSchema.TableName()); + + auto scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + + stmt.Prepare(sqlQueryString); + stmt.Execute(); + + while (true) + { + Derived record; + + stmt.BindOutputColumn(1, &record.m_data->id.value); + for (AbstractField* field: record.AllFields()) + field->BindOutputColumn(stmt); + + if (!stmt.FetchRow()) + break; + + scopedModelSqlLogger += record; + + allModels.emplace_back(std::move(record)); + } + + return allModels; +} + +template +std::string Record::CreateTableString(SqlServerType serverType) +{ + SqlTraits const& traits = GetSqlTraits(serverType); // TODO: take server type from connection + detail::StringBuilder sql; + auto model = Derived(); + model.SortFieldsByIndex(); + + // TODO: verify that field indices are unique, contiguous, and start at 1 + // TODO: verify that the primary key is the first field + // TODO: verify that the primary key is not nullable + + sql << "CREATE TABLE " << model.TableName() << " (\n"; + + sql << " " << model.PrimaryKeyName() << " " << traits.PrimaryKeyAutoIncrement << ",\n"; + + std::vector sqlConstraints; + + for (auto const* field: model.AllFields()) + { + sql << " " << field->Name() << " " << traits.ColumnTypeName(field->Type()); + + if (field->IsNullable()) + sql << " NULL"; + else + sql << " NOT NULL"; + + if (auto constraint = field->SqlConstraintSpecifier(); !constraint.empty()) + sqlConstraints.emplace_back(std::move(constraint)); + + if (field != model.AllFields().back() || !sqlConstraints.empty()) + sql << ","; + sql << "\n"; + } + + for (auto const& constraint: sqlConstraints) + { + sql << " " << constraint; + if (&constraint != &sqlConstraints.back()) + sql << ","; + sql << "\n"; + } + + sql << ");\n"; + + return *sql; +} + +template +void Record::CreateTable() +{ + auto stmt = SqlStatement {}; + auto const sqlQueryString = CreateTableString(stmt.Connection().ServerType()); + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + stmt.ExecuteDirect(sqlQueryString); +} + +template +void Record::DropTable() +{ + auto const sqlQueryString = std::format("DROP TABLE \"{}\"", Derived().TableName()); + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + SqlStatement().ExecuteDirect(sqlQueryString); +} + +template +RecordId Record::Create() +{ + auto stmt = SqlStatement {}; + + auto const modifiedFields = GetModifiedFields(); + + detail::StringBuilder sqlColumnsString; + detail::StringBuilder sqlValuesString; + for (auto const* field: modifiedFields) + { + if (!field->IsModified()) + { + // if (field->IsNull() && field->IsRequired()) + // { + // SqlLogger::GetLogger().OnWarning( // TODO + // std::format("Model required field not given: {}.{}", TableName(), field->Name())); + // return std::unexpected { SqlError::FAILURE }; // TODO: return + // SqlError::MODEL_REQUIRED_FIELD_NOT_GIVEN; + // } + continue; + } + + if (!sqlColumnsString.empty()) + { + sqlColumnsString << ", "; + sqlValuesString << ", "; + } + + sqlColumnsString << field->Name(); + sqlValuesString << "?"; + } + + auto const sqlInsertStmtString = + std::format("INSERT INTO {} ({}) VALUES ({})", TableName(), *sqlColumnsString, *sqlValuesString); + + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlInsertStmtString, modifiedFields); + + stmt.Prepare(sqlInsertStmtString); + + for (auto const&& [parameterIndex, field]: modifiedFields | std::views::enumerate) + field->BindInputParameter(parameterIndex + 1, stmt); + + stmt.Execute(); + + for (auto* field: AllFields()) + field->SetModified(false); + + // Update the model's ID with the last insert ID + m_data->id = RecordId { .value = stmt.LastInsertId() }; + return m_data->id; +} + +template +bool Record::Load(RecordId id) +{ + return Load(PrimaryKeyName(), id.value); +} + +template +void Record::Reload() +{ + Load(PrimaryKeyName(), Id()); +} + +template +template +bool Record::Load(std::string_view const& columnName, T const& value) +{ + SqlStatement stmt; + + auto const sqlQueryString = SqlQueryBuilder::From(TableName()) + .Select(AllFieldNames()) + .Where(columnName, SqlQueryWildcard()) + .First() + .ToSql(stmt.Connection().QueryFormatter()); + + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, AllFields()); + + stmt.Prepare(sqlQueryString); + stmt.BindInputParameter(1, value); + stmt.BindOutputColumn(1, &m_data->id.value); + for (AbstractField* field: AllFields()) + field->BindOutputColumn(stmt); + stmt.Execute(); + return stmt.FetchRow(); +} + +template +void Record::Update() +{ + auto sqlColumnsString = detail::StringBuilder {}; + auto modifiedFields = GetModifiedFields(); + + for (AbstractField* field: modifiedFields) + { + if (!field->IsModified()) + continue; + + if (!sqlColumnsString.empty()) + sqlColumnsString << ", "; + + sqlColumnsString << field->Name() << " = ?"; + } + + auto stmt = SqlStatement {}; + + auto const sqlQueryString = + std::format("UPDATE {} SET {} WHERE {} = {}", TableName(), *sqlColumnsString, PrimaryKeyName(), Id()); + + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, modifiedFields); + + stmt.Prepare(sqlQueryString); + + for (auto const&& [index, field]: modifiedFields | std::views::enumerate) + field->BindInputParameter(index + 1, stmt); + + stmt.Execute(); + + for (auto* field: modifiedFields) + field->SetModified(false); +} + +template +void Record::Save() +{ + if (Id().value != 0) + return Update(); + + Create(); +} + +template +void Record::Destroy() +{ + auto const sqlQueryString = std::format("DELETE FROM {} WHERE {} = {}", TableName(), PrimaryKeyName(), *Id()); + auto const scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + auto stmt = SqlStatement {}; + auto const& sqlTraits = stmt.Connection().Traits(); + stmt.ExecuteDirect(sqlTraits.EnforceForeignKeyConstraint); + stmt.ExecuteDirect(sqlQueryString); +} + +template +RecordQueryBuilder Record::Build() +{ + return RecordQueryBuilder(); +} + +template +template +RecordQueryBuilder Record::Join() +{ + return RecordQueryBuilder().template Join(); +} + +template +RecordQueryBuilder Record::Join(std::string_view joinTable, + std::string_view joinColumnName, + SqlQualifiedTableColumnName onComparisonColumn) +{ + return RecordQueryBuilder().Join(joinTable, joinColumnName, onComparisonColumn); +} + +template +template +RecordQueryBuilder Record::Where(std::string_view columnName, Value const& value) +{ + return Where(columnName, SqlWhereOperator::EQUAL, value); +} + +template +template +RecordQueryBuilder Record::Where(std::string_view columnName, + SqlWhereOperator whereOperator, + Value const& value) +{ + static_assert(std::is_move_constructible_v, + "The model `Derived` must be move constructible for Where() to return the models."); + +#if 1 + return RecordQueryBuilder().Where(columnName, whereOperator, value); +#else + std::vector allModels; + + Derived modelSchema; + + detail::StringBuilder sqlColumnsString; + sqlColumnsString << modelSchema.PrimaryKeyName(); + for (AbstractField const* field: modelSchema.AllFields()) + sqlColumnsString << ", " << field->Name(); + + auto const sqlQueryString = std::format("SELECT {} FROM {} WHERE \"{}\" {} ?", + *sqlColumnsString, + modelSchema.TableName(), + columnName, + sqlOperatorString(whereOperator)); + return Query(sqlQueryString, value); +#endif +} + +template +template +std::vector Record::Query(std::string_view sqlQueryString, InputParameters&&... inputParameters) +{ + static_assert(std::is_move_constructible_v, + "The model `Derived` must be move constructible for Where() to return the models."); + + std::vector output; + Each([&output](Derived& model) { output.push_back(std::move(model)); }, + sqlQueryString, + std::forward(inputParameters)...); + return { std::move(output) }; +} + +template +template +void Record::Each(Callback&& callback, std::string_view sqlQueryString, InputParameters&&... inputParameters) +{ + SqlStatement stmt; + + auto scopedModelSqlLogger = detail::SqlScopedModelQueryLogger(sqlQueryString, {}); + + stmt.Prepare(sqlQueryString); + + SQLSMALLINT inputParameterPosition = 0; + (stmt.BindInputParameter(++inputParameterPosition, std::forward(inputParameters)), ...); + + stmt.Execute(); + + while (true) + { + Derived record; + + stmt.BindOutputColumn(1, &record.m_data->id.value); + for (AbstractField* field: record.AllFields()) + field->BindOutputColumn(stmt); + + if (!stmt.FetchRow()) + break; + + scopedModelSqlLogger += record; + + callback(record); + } +} + +// }}} + +} // namespace Model diff --git a/src/Lightweight/Model/RecordId.hpp b/src/Lightweight/Model/RecordId.hpp new file mode 100644 index 00000000..115a18ce --- /dev/null +++ b/src/Lightweight/Model/RecordId.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../../Lightweight/SqlDataBinder.hpp" + +#include + +namespace Model +{ + +// Represents a unique identifier of a specific record in a table. +struct RecordId +{ + using InnerType = size_t; + InnerType value; + + constexpr InnerType operator*() const noexcept + { + return value; + } + + constexpr std::weak_ordering operator<=>(RecordId const& other) const noexcept = default; + + constexpr bool operator==(RecordId other) const noexcept + { + return value == other.value; + } + + constexpr bool operator==(InnerType other) const noexcept + { + return value == other; + } +}; + +} // namespace Model + +template +struct WhereConditionLiteralType; + +template <> +struct WhereConditionLiteralType +{ + constexpr static bool needsQuotes = false; +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLSMALLINT column, Model::RecordId const& value) + { + return SqlDataBinder::InputParameter(stmt, column, value.value); + } + + static SQLRETURN OutputColumn( + SQLHSTMT stmt, SQLSMALLINT column, Model::RecordId* result, SQLLEN* indicator, SqlDataBinderCallback& cb) + { + return SqlDataBindervalue)>::OutputColumn(stmt, column, &result->value, indicator, cb); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLSMALLINT column, Model::RecordId* result, SQLLEN* indicator) + { + return SqlDataBindervalue)>::GetColumn(stmt, column, &result->value, indicator); + } +}; + +template <> +struct std::formatter: std::formatter +{ + auto format(Model::RecordId id, format_context& ctx) const -> format_context::iterator + { + return formatter::format(id.value, ctx); + } +}; diff --git a/src/Lightweight/Model/StringLiteral.hpp b/src/Lightweight/Model/StringLiteral.hpp new file mode 100644 index 00000000..55558aae --- /dev/null +++ b/src/Lightweight/Model/StringLiteral.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include + +namespace Model +{ + +template +struct StringLiteral +{ + constexpr StringLiteral(const char (&str)[N]) noexcept + { + std::copy_n(str, N, value); + } + + constexpr std::string_view operator*() const noexcept + { + return value; + } + + char value[N]; +}; + +} // namespace Model diff --git a/src/Lightweight/Model/Utils.hpp b/src/Lightweight/Model/Utils.hpp new file mode 100644 index 00000000..c3ec8e6f --- /dev/null +++ b/src/Lightweight/Model/Utils.hpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "../SqlConnection.hpp" +#include "./Detail.hpp" +#include "./Record.hpp" + +#include + +namespace Model +{ + +template +std::string CreateSqlTablesString(SqlServerType serverType) +{ + detail::StringBuilder result; + result << ((Models::CreateTableString(serverType) << "\n") << ...); + return *result; +} + +template +void CreateSqlTables() +{ + (Models::CreateTable(), ...); +} + +} // namespace Model diff --git a/src/Lightweight/SqlComposedQuery.cpp b/src/Lightweight/SqlComposedQuery.cpp new file mode 100644 index 00000000..dc252cfe --- /dev/null +++ b/src/Lightweight/SqlComposedQuery.cpp @@ -0,0 +1,183 @@ +#include "SqlComposedQuery.hpp" +#include "SqlQueryFormatter.hpp" + +#include + +// {{{ SqlQueryBuilder impl + +SqlQueryBuilder SqlQueryBuilder::From(std::string_view table) +{ + return SqlQueryBuilder(table); +} + +SqlQueryBuilder::SqlQueryBuilder(std::string_view table) +{ + m_query.table = table; +} + +SqlQueryBuilder& SqlQueryBuilder::Select(std::vector const& fieldNames) +{ + for (auto const& fieldName: fieldNames) + { + if (!m_query.fields.empty()) + m_query.fields += ", "; + + m_query.fields += '"'; + m_query.fields += fieldName; + m_query.fields += '"'; + } + return *this; +} + +SqlQueryBuilder& SqlQueryBuilder::Select(std::vector const& fieldNames, std::string_view tableName) +{ + for (auto const& fieldName: fieldNames) + { + if (!m_query.fields.empty()) + m_query.fields += ", "; + + m_query.fields += '"'; + m_query.fields += tableName; + m_query.fields += "\".\""; + m_query.fields += fieldName; + m_query.fields += '"'; + } + return *this; +} + +SqlQueryBuilder& SqlQueryBuilder::InnerJoin(std::string_view joinTable, + std::string_view joinColumnName, + SqlQualifiedTableColumnName onOtherColumn) +{ + m_query.tableJoins += std::format("\n " + R"(INNER JOIN "{0}" ON "{0}"."{1}" = "{2}"."{3}")", + joinTable, + joinColumnName, + onOtherColumn.tableName, + onOtherColumn.columnName); + return *this; +} + +SqlQueryBuilder& SqlQueryBuilder::InnerJoin(std::string_view joinTable, + std::string_view joinColumnName, + std::string_view onMainTableColumn) +{ + return InnerJoin(joinTable, + joinColumnName, + SqlQualifiedTableColumnName { .tableName = m_query.table, .columnName = onMainTableColumn }); +} + +SqlQueryBuilder& SqlQueryBuilder::Where(std::string_view sqlConditionExpression) +{ + if (m_query.condition.empty()) + m_query.condition += " WHERE "; + else + m_query.condition += " AND "; + + m_query.condition += "("; + m_query.condition += std::string(sqlConditionExpression); + m_query.condition += ")"; + + return *this; +} + +SqlQueryBuilder& SqlQueryBuilder::OrderBy(std::string_view columnName, SqlResultOrdering ordering) +{ + if (m_query.orderBy.empty()) + m_query.orderBy += " ORDER BY "; + else + m_query.orderBy += ", "; + + m_query.orderBy += '"'; + m_query.orderBy += columnName; + m_query.orderBy += '"'; + + if (ordering == SqlResultOrdering::DESCENDING) + m_query.orderBy += " DESC"; + else if (ordering == SqlResultOrdering::ASCENDING) + m_query.orderBy += " ASC"; + return *this; +} + +SqlQueryBuilder& SqlQueryBuilder::GroupBy(std::string_view columnName) +{ + if (m_query.groupBy.empty()) + m_query.groupBy += " GROUP BY "; + else + m_query.groupBy += ", "; + + m_query.groupBy += '"'; + m_query.groupBy += columnName; + m_query.groupBy += '"'; + + return *this; +} + +SqlComposedQuery SqlQueryBuilder::Count() +{ + m_query.type = SqlQueryType::SELECT_COUNT; + + return std::move(m_query); +} + +SqlComposedQuery SqlQueryBuilder::All() +{ + m_query.type = SqlQueryType::SELECT_ALL; + + return std::move(m_query); +} + +SqlComposedQuery SqlQueryBuilder::First(size_t count) +{ + m_query.type = SqlQueryType::SELECT_FIRST; + m_query.limit = count; + + return std::move(m_query); +} + +SqlComposedQuery SqlQueryBuilder::Range(std::size_t offset, std::size_t limit) +{ + m_query.type = SqlQueryType::SELECT_RANGE; + m_query.offset = offset; + m_query.limit = limit; + + return std::move(m_query); +} + +std::string SqlComposedQuery::ToSql(SqlQueryFormatter const& formatter) const +{ + std::string finalConditionBuffer; + std::string const* finalCondition = &condition; + + if (!booleanLiteralConditions.empty()) + { + finalConditionBuffer = condition; + finalCondition = &finalConditionBuffer; + for (auto&& [column, binaryOp, literalValue]: booleanLiteralConditions) + { + if (finalConditionBuffer.empty()) + finalConditionBuffer += " WHERE "; + else + finalConditionBuffer += " AND "; + + finalConditionBuffer += formatter.BooleanWhereClause(column, binaryOp, literalValue); + } + } + + switch (type) + { + case SqlQueryType::UNDEFINED: + break; + case SqlQueryType::SELECT_ALL: + return formatter.SelectAll(fields, table, tableJoins, *finalCondition, orderBy, groupBy); + case SqlQueryType::SELECT_FIRST: + return formatter.SelectFirst(fields, table, tableJoins, *finalCondition, orderBy, limit); + case SqlQueryType::SELECT_RANGE: + return formatter.SelectRange(fields, table, tableJoins, *finalCondition, orderBy, groupBy, offset, limit); + case SqlQueryType::SELECT_COUNT: + return formatter.SelectCount(table, tableJoins, *finalCondition); + } + return ""; +} + +// }}} diff --git a/src/Lightweight/SqlComposedQuery.hpp b/src/Lightweight/SqlComposedQuery.hpp new file mode 100644 index 00000000..fa6efe81 --- /dev/null +++ b/src/Lightweight/SqlComposedQuery.hpp @@ -0,0 +1,214 @@ +#pragma once + +#include "SqlDataBinder.hpp" + +#include +#include +#include +#include +#include +#include +#include + +enum class SqlResultOrdering : uint8_t +{ + ASCENDING, + DESCENDING +}; + +enum class SqlQueryType : uint8_t +{ + UNDEFINED, + + SELECT_ALL, + SELECT_FIRST, + SELECT_RANGE, + SELECT_COUNT, + + // INSERT, + // UPDATE, + // DELETE -- ABUSED by winnt.h on Windows as preprocessor definition. Thanks! +}; + +// SqlQueryWildcard is a placeholder for an explicit wildcard input parameter in a SQL query. +// +// Use this in the SqlQueryBuilder::Where method to insert a '?' placeholder for a wildcard. +struct SqlQueryWildcard +{ +}; + +struct SqlQualifiedTableColumnName +{ + std::string_view tableName; + std::string_view columnName; +}; + +class SqlQueryFormatter; + +struct [[nodiscard]] SqlComposedQuery +{ + SqlQueryType type = SqlQueryType::UNDEFINED; + std::string fields; + std::string table; + std::vector inputBindings; + std::string tableJoins; + std::string condition; + std::vector> + booleanLiteralConditions; + std::string orderBy; + std::string groupBy; + size_t offset = 0; + size_t limit = std::numeric_limits::max(); + + [[nodiscard]] std::string ToSql(SqlQueryFormatter const& formatter) const; +}; + +class [[nodiscard]] SqlQueryBuilder +{ + public: + static SqlQueryBuilder From(std::string_view table); + + // Adds a single column to the SELECT clause. + [[nodiscard]] SqlQueryBuilder& Select(std::vector const& fieldNames); + + // Adds a sequence of columns from the given table to the SELECT clause. + [[nodiscard]] SqlQueryBuilder& Select(std::vector const& fieldNames, std::string_view tableName); + + // Adds a sequence of columns to the SELECT clause. + template + [[nodiscard]] SqlQueryBuilder& Select(std::string_view const& firstField, MoreFields&&... moreFields); + + // Constructs or extends a raw WHERE clause. + [[nodiscard]] SqlQueryBuilder& Where(std::string_view sqlConditionExpression); + + // Constructs or extends a WHERE clause to test for a binary operation. + template + [[nodiscard]] SqlQueryBuilder& Where(ColumnName const& columnName, std::string_view binaryOp, T const& value); + + // Constructs or extends a WHERE clause to test for equality. + template + [[nodiscard]] SqlQueryBuilder& Where(ColumnName const& columnName, T const& value); + + // Constructs or extends a ORDER BY clause. + [[nodiscard]] SqlQueryBuilder& OrderBy(std::string_view columnName, + SqlResultOrdering ordering = SqlResultOrdering::ASCENDING); + + // Constructs or extends a GROUP BY clause. + [[nodiscard]] SqlQueryBuilder& GroupBy(std::string_view columnName); + + // Constructs an INNER JOIN clause. + [[nodiscard]] SqlQueryBuilder& InnerJoin(std::string_view joinTable, + std::string_view joinColumnName, + SqlQualifiedTableColumnName onOtherColumn); + + // Constructs an INNER JOIN clause. + [[nodiscard]] SqlQueryBuilder& InnerJoin(std::string_view joinTable, + std::string_view joinColumnName, + std::string_view onMainTableColumn); + + // final methods + + // Finalizes building the query as SELECT COUNT(*) ... query. + SqlComposedQuery Count(); + + // Finalizes building the query as SELECT field names FROM ... query. + SqlComposedQuery All(); + + // Finalizes building the query as SELECT TOP n field names FROM ... query. + SqlComposedQuery First(size_t count = 1); + + // Finalizes building the query as SELECT field names FROM ... query with a range. + SqlComposedQuery Range(std::size_t offset, std::size_t limit); + + private: + explicit SqlQueryBuilder(std::string_view table); + + SqlComposedQuery m_query {}; +}; + +// {{{ SqlQueryBuilder template implementations and inlines + +template +SqlQueryBuilder& SqlQueryBuilder::Select(std::string_view const& firstField, MoreFields&&... moreFields) +{ + std::ostringstream fragment; + + if (!m_query.fields.empty()) + fragment << ", "; + + fragment << "\"" << firstField << "\""; + + if constexpr (sizeof...(MoreFields) > 0) + ((fragment << ", \"" << std::forward(moreFields) << "\"") << ...); + + m_query.fields += fragment.str(); + return *this; +} + +template +SqlQueryBuilder& SqlQueryBuilder::Where(ColumnName const& columnName, T const& value) +{ + return Where(columnName, "=", value); +} + +template +struct WhereConditionLiteralType +{ + constexpr static bool needsQuotes = !std::is_integral_v && !std::is_floating_point_v; +}; + +template +SqlQueryBuilder& SqlQueryBuilder::Where(ColumnName const& columnName, std::string_view binaryOp, T const& value) +{ + if constexpr (std::is_same_v) + { + if constexpr (std::is_same_v) + m_query.booleanLiteralConditions.emplace_back(columnName, binaryOp, value); + else + m_query.booleanLiteralConditions.emplace_back( + SqlQualifiedTableColumnName { "", columnName }, binaryOp, value); + return *this; + } + + if (m_query.condition.empty()) + m_query.condition += " WHERE "; + else + m_query.condition += " AND "; + + if constexpr (std::is_same_v) + { + m_query.condition += std::format(R"("{}"."{}")", columnName.tableName, columnName.columnName); + } + else + { + m_query.condition += "\""; + m_query.condition += columnName; + m_query.condition += "\""; + } + + m_query.condition += " "; + m_query.condition += binaryOp; + m_query.condition += " "; + + if constexpr (std::is_same_v) + { + m_query.condition += "?"; + m_query.inputBindings.emplace_back(std::monostate()); + } + else if constexpr (!WhereConditionLiteralType::needsQuotes) + { + m_query.condition += std::format("{}", value); + } + else + { + m_query.condition += "'"; + m_query.condition += std::format("{}", value); + m_query.condition += "'"; + // TODO: This should be bound as an input parameter in the future instead. + // m_query.inputBindings.emplace_back(value); + } + + return *this; +} + +// }}} diff --git a/src/Lightweight/SqlConcepts.hpp b/src/Lightweight/SqlConcepts.hpp new file mode 100644 index 00000000..87ce593e --- /dev/null +++ b/src/Lightweight/SqlConcepts.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +// clang-format off +template +concept StdStringViewLike = requires(T const& t, T& u) { + { t.data() } -> std::convertible_to; + { t.size() } -> std::convertible_to; +}; + +template +concept StdStringLike = requires(T const& t, T& u) { + { t.data() } -> std::convertible_to; + { t.size() } -> std::convertible_to; + { u.clear() }; + { u.append(std::declval(), std::declval()) }; +}; + +template +concept MFCStringLike = requires(T const& t) { + { t.GetLength() } -> std::convertible_to; + { t.GetString() } -> std::convertible_to; +}; + +template +concept RNStringLike = requires(T const& t) { + { t.Length() } -> std::convertible_to; + { t.GetString() } -> std::convertible_to; +}; + +// clang-format on diff --git a/src/Lightweight/SqlConnectInfo.cpp b/src/Lightweight/SqlConnectInfo.cpp new file mode 100644 index 00000000..066fbc6c --- /dev/null +++ b/src/Lightweight/SqlConnectInfo.cpp @@ -0,0 +1,74 @@ +#include "SqlConnectInfo.hpp" + +#include +#include +#include +#include + +namespace +{ + +constexpr std::string_view DropQuotation(std::string_view value) noexcept +{ + if (!value.empty() && value.front() == '{' && value.back() == '}') + { + value.remove_prefix(1); + value.remove_suffix(1); + } + return value; +} + +constexpr std::string_view Trim(std::string_view value) noexcept +{ + while (!value.empty() && std::isspace(value.front())) + value.remove_prefix(1); + + while (!value.empty() && std::isspace(value.back())) + value.remove_suffix(1); + + return value; +} + +std::string ToUpperCaseString(std::string_view input) +{ + std::string result { input }; + std::ranges::transform(result, result.begin(), [](char c) { return (char) std::toupper(c); }); + return result; +} + +} // end namespace + +SqlConnectionStringMap ParseConnectionString(SqlConnectionString const& connectionString) +{ + auto pairs = connectionString.value | std::views::split(';') | std::views::transform([](auto pair_view) { + return std::string_view(&*pair_view.begin(), std::ranges::distance(pair_view)); + }); + + SqlConnectionStringMap result; + + for (auto const& pair: pairs) + { + auto separatorPosition = pair.find('='); + if (separatorPosition != std::string_view::npos) + { + auto const key = Trim(pair.substr(0, separatorPosition)); + auto const value = DropQuotation(Trim(pair.substr(separatorPosition + 1))); + result.insert_or_assign(ToUpperCaseString(key), std::string(value)); + } + } + + return result; +} + +SqlConnectionString BuildConnectionString(SqlConnectionStringMap const& map) +{ + SqlConnectionString result; + + for (auto const& [key, value]: map) + { + std::string_view const delimiter = result.value.empty() ? "" : ";"; + result.value += std::format("{}{}={{{}}}", delimiter, key, value); + } + + return result; +} diff --git a/src/Lightweight/SqlConnectInfo.hpp b/src/Lightweight/SqlConnectInfo.hpp new file mode 100644 index 00000000..76ddf15f --- /dev/null +++ b/src/Lightweight/SqlConnectInfo.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct SqlConnectionString +{ + std::string value; + + auto operator<=>(SqlConnectionString const&) const noexcept = default; +}; + +using SqlConnectionStringMap = std::map; + +SqlConnectionStringMap ParseConnectionString(SqlConnectionString const& connectionString); +SqlConnectionString BuildConnectionString(SqlConnectionStringMap const& map); + +struct SqlConnectionDataSource +{ + std::string datasource; + std::string username; + std::string password; + std::chrono::seconds timeout { 5 }; + + auto operator<=>(SqlConnectionDataSource const&) const noexcept = default; +}; + +using SqlConnectInfo = std::variant; + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlConnectInfo const& info, format_context& ctx) const -> format_context::iterator + { + if (auto const* dsn = std::get_if(&info)) + { + return formatter::format(std::format("DSN={};UID={};PWD={};TIMEOUT={}", + dsn->datasource, + dsn->username, + dsn->password, + dsn->timeout.count()), + ctx); + } + else if (auto const* connectionString = std::get_if(&info)) + { + return formatter::format(connectionString->value, ctx); + } + else + { + return formatter::format("Invalid connection info", ctx); + } + } +}; diff --git a/src/Lightweight/SqlConnection.cpp b/src/Lightweight/SqlConnection.cpp new file mode 100644 index 00000000..d53bacd6 --- /dev/null +++ b/src/Lightweight/SqlConnection.cpp @@ -0,0 +1,376 @@ +#include "SqlConnection.hpp" +#include "SqlQueryFormatter.hpp" + +#include +#include +#include + +using namespace std::chrono_literals; +using namespace std::string_view_literals; + +static std::list g_unusedConnections; + +class SqlConnectionPool +{ + public: + ~SqlConnectionPool() + { + KillAllIdleConnections(); + } + + void KillAllIdleConnections() + { + auto const _ = std::lock_guard { m_unusedConnectionsMutex }; + for (auto& connection: m_unusedConnections) + connection.Kill(); + m_unusedConnections.clear(); + } + + SqlResult Acquire() + { + auto connection = AcquireDirect(); + if (connection.LastError() != SqlError::SUCCESS) + return std::unexpected { connection.LastError() }; + return { std::move(connection) }; + } + + SqlConnection AcquireDirect() + { + auto const _ = std::lock_guard { m_unusedConnectionsMutex }; + + // Close idle connections + auto const now = std::chrono::steady_clock::now(); + while (!m_unusedConnections.empty() && now - m_unusedConnections.front().LastUsed() > m_connectionTimeout) + { + ++m_stats.timedout; + SqlLogger::GetLogger().OnConnectionIdle(m_unusedConnections.front()); + m_unusedConnections.front().Kill(); + m_unusedConnections.pop_front(); + } + + // Reuse an existing connection + if (!m_unusedConnections.empty()) + { + ++m_stats.reused; + auto connection = std::move(m_unusedConnections.front()); + m_unusedConnections.pop_front(); + SqlLogger::GetLogger().OnConnectionReuse(connection); + return connection; + } + + // Create a new connection + ++m_stats.created; + auto connection = SqlConnection { SqlConnection::DefaultConnectInfo() }; + return connection; + } + + void Release(SqlConnection&& connection) + { + auto const _ = std::lock_guard { m_unusedConnectionsMutex }; + ++m_stats.released; + if (m_unusedConnections.size() < m_maxIdleConnections) + { + connection.SetLastUsed(std::chrono::steady_clock::now()); + SqlLogger::GetLogger().OnConnectionReuse(connection); + m_unusedConnections.emplace_back(std::move(connection)); + } + else + { + SqlLogger::GetLogger().OnConnectionIdle(connection); + connection.Kill(); + } + } + + void SetMaxIdleConnections(size_t maxIdleConnections) noexcept + { + m_maxIdleConnections = maxIdleConnections; + } + + [[nodiscard]] SqlConnectionStats Stats() const noexcept + { + return m_stats; + } + + private: + std::list m_unusedConnections; + std::mutex m_unusedConnectionsMutex; + size_t m_maxIdleConnections = 10; + std::chrono::seconds m_connectionTimeout = std::chrono::seconds { 120 }; + SqlConnectionStats m_stats; +}; + +static SqlConnectionPool g_connectionPool; + +// ===================================================================================================================== + +SqlConnection::SqlConnection() noexcept: + SqlConnection(g_connectionPool.AcquireDirect()) +{ +} + +SqlConnection::SqlConnection(SqlConnectInfo const& connectInfo) noexcept +{ + SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &m_hEnv); + SQLSetEnvAttr(m_hEnv, SQL_ATTR_ODBC_VERSION, (SQLPOINTER) SQL_OV_ODBC3, 0); + SQLAllocHandle(SQL_HANDLE_DBC, m_hEnv, &m_hDbc); + Connect(connectInfo); +} + +SqlConnection::SqlConnection(SqlConnection&& other) noexcept: + m_hEnv { other.m_hEnv }, + m_hDbc { other.m_hDbc }, + m_connectionId { other.m_connectionId }, + m_lastError { other.m_lastError }, + m_connectInfo { std::move(other.m_connectInfo) }, + m_lastUsed { other.m_lastUsed }, + m_serverType { other.m_serverType }, + m_queryFormatter { other.m_queryFormatter } +{ + other.m_hEnv = {}; + other.m_hDbc = {}; +} + +SqlConnection& SqlConnection::operator=(SqlConnection&& other) noexcept +{ + if (this == &other) + return *this; + + Close(); + + m_hEnv = other.m_hEnv; + m_hDbc = other.m_hDbc; + m_connectionId = other.m_connectionId; + m_lastError = other.m_lastError; + m_connectInfo = std::move(other.m_connectInfo); + m_lastUsed = other.m_lastUsed; + + other.m_hEnv = {}; + other.m_hDbc = {}; + + return *this; +} + +SqlConnection::~SqlConnection() noexcept +{ + Close(); +} + +void SqlConnection::SetMaxIdleConnections(size_t maxIdleConnections) noexcept +{ + g_connectionPool.SetMaxIdleConnections(maxIdleConnections); +} + +void SqlConnection::KillAllIdle() +{ + g_connectionPool.KillAllIdleConnections(); +} + +void SqlConnection::SetPostConnectedHook(std::function hook) +{ + m_gPostConnectedHook = std::move(hook); +} + +void SqlConnection::ResetPostConnectedHook() +{ + m_gPostConnectedHook = {}; +} + +SqlConnectionStats SqlConnection::Stats() noexcept +{ + return g_connectionPool.Stats(); +} + +bool SqlConnection::Connect(std::string_view datasource, std::string_view username, std::string_view password) noexcept +{ + return Connect(SqlConnectionDataSource { + .datasource = std::string(datasource), .username = std::string(username), .password = std::string(password) }); +} + +// Connects to the given database with the given ODBC connection string. +bool SqlConnection::Connect(std::string connectionString) noexcept +{ + return Connect(SqlConnectionString { .value = std::move(connectionString) }); +} + +void SqlConnection::PostConnect() +{ + auto const mappings = std::array { + std::pair { "Microsoft SQL Server"sv, SqlServerType::MICROSOFT_SQL }, + std::pair { "PostgreSQL"sv, SqlServerType::POSTGRESQL }, + std::pair { "Oracle"sv, SqlServerType::ORACLE }, + std::pair { "SQLite"sv, SqlServerType::SQLITE }, + std::pair { "MySQL"sv, SqlServerType::MYSQL }, + }; + + auto const serverName = ServerName(); + for (auto const& [name, type]: mappings) + { + if (serverName.contains(name)) + { + m_serverType = type; + break; + } + } + + m_queryFormatter = SqlQueryFormatter::Get(m_serverType); +} + +// Connects to the given database with the given username and password. +bool SqlConnection::Connect(SqlConnectInfo connectInfo) noexcept +{ + m_connectInfo = std::move(connectInfo); + + if (auto const* info = std::get_if(&m_connectInfo)) + { + UpdateLastError(SQLSetConnectAttrA(m_hDbc, SQL_LOGIN_TIMEOUT, (SQLPOINTER) info->timeout.count(), 0)); + return UpdateLastError(SQLConnectA(m_hDbc, + (SQLCHAR*) info->datasource.data(), + (SQLSMALLINT) info->datasource.size(), + (SQLCHAR*) info->username.data(), + (SQLSMALLINT) info->username.size(), + (SQLCHAR*) info->password.data(), + (SQLSMALLINT) info->password.size())) + .and_then([&] { + return UpdateLastError( + SQLSetConnectAttrA(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER)); + }) + .and_then([&]() -> SqlResult { + PostConnect(); + SqlLogger::GetLogger().OnConnectionOpened(*this); + if (m_gPostConnectedHook) + m_gPostConnectedHook(*this); + return {}; + }) + .or_else([&](auto&&) -> SqlResult { + SqlLogger::GetLogger().OnError(m_lastError, SqlErrorInfo::fromConnectionHandle(m_hDbc)); + return std::unexpected { m_lastError }; + }) + .has_value(); + } + + auto const& connectionString = std::get(m_connectInfo).value; + + return UpdateLastError(SQLDriverConnectA(m_hDbc, + (SQLHWND) nullptr, + (SQLCHAR*) connectionString.data(), + (SQLSMALLINT) connectionString.size(), + nullptr, + 0, + nullptr, + SQL_DRIVER_NOPROMPT)) + .and_then([&] { + return UpdateLastError( + SQLSetConnectAttrA(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER)); + }) + .and_then([&]() -> SqlResult { + PostConnect(); + SqlLogger::GetLogger().OnConnectionOpened(*this); + if (m_gPostConnectedHook) + m_gPostConnectedHook(*this); + return {}; + }) + .has_value(); +} + +void SqlConnection::Close() noexcept +{ + if (!m_hDbc) + return; + + if (m_connectInfo == DefaultConnectInfo()) + g_connectionPool.Release(std::move(*this)); + else + Kill(); +} + +void SqlConnection::Kill() noexcept +{ + if (!m_hDbc) + return; + + SqlLogger::GetLogger().OnConnectionClosed(*this); + + SQLDisconnect(m_hDbc); + SQLFreeHandle(SQL_HANDLE_DBC, m_hDbc); + SQLFreeHandle(SQL_HANDLE_ENV, m_hEnv); + + m_hDbc = {}; + m_hEnv = {}; +} + +std::string SqlConnection::DatabaseName() const +{ + std::string name(128, '\0'); + SQLSMALLINT nameLen {}; + RequireSuccess(SQLGetInfoA(m_hDbc, SQL_DATABASE_NAME, name.data(), (SQLSMALLINT) name.size(), &nameLen)); + name.resize(nameLen); + return name; +} + +std::string SqlConnection::UserName() const +{ + std::string name(128, '\0'); + SQLSMALLINT nameLen {}; + RequireSuccess(SQLGetInfoA(m_hDbc, SQL_USER_NAME, name.data(), (SQLSMALLINT) name.size(), &nameLen)); + name.resize(nameLen); + return name; +} + +std::string SqlConnection::ServerName() const +{ + std::string name(128, '\0'); + SQLSMALLINT nameLen {}; + RequireSuccess(SQLGetInfoA(m_hDbc, SQL_DBMS_NAME, (SQLPOINTER) name.data(), (SQLSMALLINT) name.size(), &nameLen)); + name.resize(nameLen); + return name; +} + +std::string SqlConnection::ServerVersion() const +{ + std::string text(128, '\0'); + SQLSMALLINT textLen {}; + RequireSuccess(SQLGetInfoA(m_hDbc, SQL_DBMS_VER, (SQLPOINTER) text.data(), (SQLSMALLINT) text.size(), &textLen)); + text.resize(textLen); + return text; +} + +bool SqlConnection::TransactionActive() const noexcept +{ + SQLUINTEGER state {}; + UpdateLastError(SQLGetConnectAttrA(m_hDbc, SQL_ATTR_AUTOCOMMIT, &state, 0, nullptr)); + return m_lastError == SqlError::SUCCESS && state == SQL_AUTOCOMMIT_OFF; +} + +bool SqlConnection::TransactionsAllowed() const noexcept +{ + SQLUSMALLINT txn {}; + SQLSMALLINT t {}; + SQLRETURN const rv = SQLGetInfo(m_hDbc, (SQLUSMALLINT) SQL_TXN_CAPABLE, &txn, sizeof(txn), &t); + return rv == SQL_SUCCESS && txn != SQL_TC_NONE; +} + +bool SqlConnection::IsAlive() const noexcept +{ + SQLUINTEGER state {}; + UpdateLastError(SQLGetConnectAttrA(m_hDbc, SQL_ATTR_CONNECTION_DEAD, &state, 0, nullptr)); + return m_lastError == SqlError::SUCCESS && state == SQL_CD_FALSE; +} + +void SqlConnection::RequireSuccess(SQLRETURN error, std::source_location sourceLocation) const +{ + auto result = detail::UpdateSqlError(&m_lastError, error); + if (result.has_value()) + return; + + auto errorInfo = SqlErrorInfo::fromConnectionHandle(m_hDbc); + SqlLogger::GetLogger().OnError(m_lastError, errorInfo, sourceLocation); + throw std::runtime_error(std::format("SQL error: {}", errorInfo)); +} + +SqlResult SqlConnection::UpdateLastError(SQLRETURN error, std::source_location sourceLocation) const noexcept +{ + return detail::UpdateSqlError(&m_lastError, error).or_else([&](auto&&) -> SqlResult { + SqlLogger::GetLogger().OnError(m_lastError, SqlErrorInfo::fromConnectionHandle(m_hDbc), sourceLocation); + return std::unexpected { m_lastError }; + }); +} diff --git a/src/Lightweight/SqlConnection.hpp b/src/Lightweight/SqlConnection.hpp new file mode 100644 index 00000000..497fe5ca --- /dev/null +++ b/src/Lightweight/SqlConnection.hpp @@ -0,0 +1,198 @@ +#pragma once + +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "SqlConcepts.hpp" +#include "SqlConnectInfo.hpp" +#include "SqlError.hpp" +#include "SqlLogger.hpp" +#include "SqlTraits.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +class SqlQueryFormatter; + +// @brief Represents a connection to a SQL database. +class SqlConnection final +{ + public: + // Constructs a new SQL connection to the default connection. + // + // The default connection is set via SetDefaultConnectInfo. + // In case the default connection is not set, the connection will fail. + // And in case the connection fails, the last error will be set. + SqlConnection() noexcept; + + // Constructs a new SQL connection to the given connect informaton. + explicit SqlConnection(SqlConnectInfo const& connectInfo) noexcept; + + SqlConnection(SqlConnection&&) noexcept; + SqlConnection& operator=(SqlConnection&&) noexcept; + SqlConnection(SqlConnection const&) = delete; + SqlConnection& operator=(SqlConnection const&) = delete; + + // Destructs this SQL connection object, + ~SqlConnection() noexcept; + + // Retrieves the default connection information. + static SqlConnectInfo const& DefaultConnectInfo() noexcept + { + return m_gDefaultConnectInfo.value(); + } + + // Sets the default connection information. + static void SetDefaultConnectInfo(SqlConnectInfo connectInfo) noexcept + { + m_gDefaultConnectInfo = std::move(connectInfo); + } + + // Sets the maximum number of idle connections in the connection pool. + static void SetMaxIdleConnections(size_t maxIdleConnections) noexcept; + + // Kills all idle connections in the connection pool. + static void KillAllIdle(); + + static void SetPostConnectedHook(std::function hook); + static void ResetPostConnectedHook(); + + static SqlConnectionStats Stats() noexcept; + + // Retrieves the connection ID. + // + // This is a unique identifier for the connection, which is useful for debugging purposes. + // Note, this ID will not change if the connection is moved nor when it is reused via the connection pool. + [[nodiscard]] uint64_t ConnectionId() const noexcept + { + return m_connectionId; + } + + // Closes the connection (attempting to put it back into the connection pool). + void Close() noexcept; + + // Kills the connection. + void Kill() noexcept; + + // Connects to the given database with the given username and password. + bool Connect(std::string_view datasource, std::string_view username, std::string_view password) noexcept; + + // Connects to the given database with the given ODBC connection string. + bool Connect(std::string connectionString) noexcept; + + // Connects to the given database with the given username and password. + bool Connect(SqlConnectInfo connectInfo) noexcept; + + // Retrieves the name of the database in use. + [[nodiscard]] std::string DatabaseName() const; + + // Retrieves the name of the user. + [[nodiscard]] std::string UserName() const; + + // Retrieves the name of the server. + [[nodiscard]] std::string ServerName() const; + + // Retrieves the reported server version. + [[nodiscard]] std::string ServerVersion() const; + + // Retrieves the type of the server. + [[nodiscard]] SqlServerType ServerType() const noexcept; + + // Retrieves a query formatter suitable for the SQL server being connected. + [[nodiscard]] SqlQueryFormatter const& QueryFormatter() const noexcept; + + // Retrieves the SQL traits for the server. + [[nodiscard]] SqlTraits const& Traits() const noexcept + { + return GetSqlTraits(ServerType()); + } + + // Tests if a transaction is active. + [[nodiscard]] bool TransactionActive() const noexcept; + + // Tests if transactions are allowed. + [[nodiscard]] bool TransactionsAllowed() const noexcept; + + // Tests if the connection is still active. + [[nodiscard]] bool IsAlive() const noexcept; + + // Retrieves the connection information. + [[nodiscard]] SqlConnectInfo const& ConnectionInfo() const noexcept + { + return m_connectInfo; + } + + // Retrieves the native handle. + [[nodiscard]] SQLHDBC NativeHandle() const noexcept + { + return m_hDbc; + } + + // Retrieves the last error code. + [[nodiscard]] SqlError LastError() const noexcept + { + return m_lastError; + } + + // Retrieves the last time the connection was used. + [[nodiscard]] std::chrono::steady_clock::time_point LastUsed() const noexcept + { + return m_lastUsed; + } + + // Sets the last time the connection was used. + void SetLastUsed(std::chrono::steady_clock::time_point lastUsed) noexcept + { + m_lastUsed = lastUsed; + } + + private: + void PostConnect(); + + void RequireSuccess(SQLRETURN error, std::source_location sourceLocation = std::source_location::current()) const; + + // Updates the last error code and returns the error code as an SqlResult if the operation failed. + // + // We also log here the error message. + SqlResult UpdateLastError( + SQLRETURN error, std::source_location sourceLocation = std::source_location::current()) const noexcept; + + // Private data members + + static inline std::optional m_gDefaultConnectInfo; + static inline std::atomic m_gNextConnectionId { 1 }; + static inline std::function m_gPostConnectedHook {}; + + SQLHENV m_hEnv {}; + SQLHDBC m_hDbc {}; + uint64_t m_connectionId { m_gNextConnectionId++ }; + mutable SqlError m_lastError {}; + SqlConnectInfo m_connectInfo; + std::chrono::steady_clock::time_point m_lastUsed; // Last time the connection was used (mostly interesting for + // idle connections in the connection pool). + SqlServerType m_serverType = SqlServerType::UNKNOWN; + SqlQueryFormatter const* m_queryFormatter {}; +}; + +inline SqlServerType SqlConnection::ServerType() const noexcept +{ + return m_serverType; +} + +inline SqlQueryFormatter const& SqlConnection::QueryFormatter() const noexcept +{ + return *m_queryFormatter; +} diff --git a/src/Lightweight/SqlDataBinder.hpp b/src/Lightweight/SqlDataBinder.hpp new file mode 100644 index 00000000..9ae92640 --- /dev/null +++ b/src/Lightweight/SqlDataBinder.hpp @@ -0,0 +1,1196 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "SqlLogger.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace detail +{ + +template +constexpr Integer toInteger(std::string_view s, Integer fallback) noexcept +{ + Integer value {}; + auto const rc = std::from_chars(s.data(), s.data() + s.size(), value); +#if __cpp_lib_to_chars >= 202306L + if (rc) +#else + if (rc.ec == std::errc {}) +#endif + return value; + else + return fallback; +} + +} // namespace detail + +// clang-format off +#if !defined(SQL_SS_TIME2) +// This is a Microsoft-specific extension to ODBC. +// It is supported by at lesat the following drivers: +// - SQL Server 2008 and later +// - MariaDB and MySQL ODBC drivers + +#define SQL_SS_TIME2 (-154) + +struct SQL_SS_TIME2_STRUCT +{ + SQLUSMALLINT hour; + SQLUSMALLINT minute; + SQLUSMALLINT second; + SQLUINTEGER fraction; +}; + +static_assert( + sizeof(SQL_SS_TIME2_STRUCT) == 12, + "SQL_SS_TIME2_STRUCT size must be padded 12 bytes, as per ODBC extension spec." +); + +#endif +// clang-format on + +// Helper struct to store a string that should be automatically trimmed when fetched from the database. +// This is only needed for compatibility with old columns that hard-code the length, like CHAR(50). +struct SqlTrimmedString +{ + std::string value; + + std::weak_ordering operator<=>(SqlTrimmedString const&) const noexcept = default; +}; + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlTrimmedString const& text, format_context& ctx) const -> format_context::iterator + { + return std::formatter::format(text.value, ctx); + } +}; + +// Represents a TEXT field in a SQL database. +// +// This is used for large texts, e.g. up to 65k characters. +struct SqlText +{ + using value_type = std::string; + + value_type value; + + std::weak_ordering operator<=>(SqlText const&) const noexcept = default; +}; + +enum class SqlStringPostRetrieveOperation : uint8_t +{ + NOTHING, + TRIM_RIGHT, +}; + +// SQL fixed-capacity string that mimmicks standard library string/string_view with a fixed-size underlying +// buffer. +// +// The underlying storage will not be guaranteed to be `\0`-terminated unless +// a call to mutable/const c_str() has been performed. +template +class SqlFixedString +{ + private: + T _data[N + 1] {}; + std::size_t _size = 0; + + public: + using value_type = T; + using iterator = T*; + using const_iterator = T const*; + using pointer_type = T*; + using const_pointer_type = T const*; + + static constexpr std::size_t Capacity = N; + static constexpr SqlStringPostRetrieveOperation PostRetrieveOperation = PostOp; + + template + SqlFixedString(T const (&text)[SourceSize]): + _size { SourceSize - 1 } + { + static_assert(SourceSize <= N + 1, "RHS string size must not exceed target string's capacity."); + std::copy_n(text, SourceSize, _data); + } + + SqlFixedString() = default; + SqlFixedString(SqlFixedString const&) = default; + SqlFixedString& operator=(SqlFixedString const&) = default; + SqlFixedString(SqlFixedString&&) = default; + SqlFixedString& operator=(SqlFixedString&&) = default; + ~SqlFixedString() = default; + + void reserve(std::size_t capacity) + { + if (capacity > N) + throw std::length_error( + std::format("SqlFixedString: capacity {} exceeds maximum capacity {}", capacity, N)); + } + + [[nodiscard]] constexpr bool empty() const noexcept + { + return _size == 0; + } + + [[nodiscard]] constexpr std::size_t size() const noexcept + { + return _size; + } + + constexpr void setsize(std::size_t n) noexcept + { + auto const newSize = (std::min)(n, N); + _size = newSize; + } + + constexpr void resize(std::size_t n, T c = T {}) noexcept + { + auto const newSize = (std::min)(n, N); + if (newSize > _size) + std::fill_n(end(), newSize - _size, c); + _size = newSize; + } + + [[nodiscard]] constexpr std::size_t capacity() const noexcept + { + return N; + } + + constexpr void clear() noexcept + { + _size = 0; + } + + template + constexpr void assign(T const (&source)[SourceSize]) noexcept + { + static_assert(SourceSize <= N + 1, "Source string must not overflow the target string's capacity."); + _size = SourceSize - 1; + std::copy_n(source, SourceSize, _data); + } + + constexpr void assign(std::string_view s) noexcept + { + _size = (std::min)(N, s.size()); + std::copy_n(s.data(), _size, _data); + } + + constexpr void push_back(T c) noexcept + { + if (_size < N) + { + _data[_size] = c; + ++_size; + } + } + + constexpr void pop_back() noexcept + { + if (_size > 0) + --_size; + } + + constexpr std::basic_string_view substr( + std::size_t offset = 0, std::size_t count = (std::numeric_limits::max)()) const noexcept + { + if (offset >= _size) + return {}; + if (count == (std::numeric_limits::max)()) + return std::basic_string_view(_data + offset, _size - offset); + if (offset + count > _size) + return std::basic_string_view(_data + offset, _size - offset); + return std::basic_string_view(_data + offset, count); + } + + // clang-format off + constexpr pointer_type c_str() noexcept { _data[_size] = '\0'; return _data; } + constexpr pointer_type data() noexcept { return _data; } + constexpr iterator begin() noexcept { return _data; } + constexpr iterator end() noexcept { return _data + size(); } + constexpr T& at(std::size_t i) noexcept { return _data[i]; } + constexpr T& operator[](std::size_t i) noexcept { return _data[i]; } + + constexpr const_pointer_type c_str() const noexcept { const_cast(_data)[_size] = '\0'; return _data; } + constexpr const_pointer_type data() const noexcept { return _data; } + constexpr const_iterator begin() const noexcept { return _data; } + constexpr const_iterator end() const noexcept { return _data + size(); } + constexpr T const& at(std::size_t i) const noexcept { return _data[i]; } + constexpr T const& operator[](std::size_t i) const noexcept { return _data[i]; } + // clang-format on + + template + std::weak_ordering operator<=>(SqlFixedString const& other) const noexcept + { + if ((void*) this != (void*) &other) + { + for (std::size_t i = 0; i < (std::min)(N, OtherSize); ++i) + if (auto const cmp = _data[i] <=> other._data[i]; cmp != std::weak_ordering::equivalent) + return cmp; + if constexpr (N != OtherSize) + return N <=> OtherSize; + } + return std::weak_ordering::equivalent; + } + + template + constexpr bool operator==(SqlFixedString const& other) const noexcept + { + return (*this <=> other) == std::weak_ordering::equivalent; + } + + template + constexpr bool operator!=(SqlFixedString const& other) const noexcept + { + return !(*this == other); + } + + constexpr bool operator==(std::string_view other) const noexcept + { + return (substr() <=> other) == std::weak_ordering::equivalent; + } + + constexpr bool operator!=(std::string_view other) const noexcept + { + return !(*this == other); + } +}; + +template +using SqlTrimmedFixedString = SqlFixedString; + +template +struct std::formatter>: std::formatter +{ + using value_type = SqlFixedString; + auto format(value_type const& text, format_context& ctx) const -> format_context::iterator + { + return std::formatter::format(text.c_str(), ctx); + } +}; + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlText const& text, format_context& ctx) const -> format_context::iterator + { + return std::formatter::format(text.value, ctx); + } +}; + +// Helper struct to store a date (without time of the day) to write to or read from a database. +struct SqlDate +{ + SQL_DATE_STRUCT sqlValue {}; + + SqlDate() noexcept = default; + SqlDate(SqlDate&&) noexcept = default; + SqlDate& operator=(SqlDate&&) noexcept = default; + SqlDate(SqlDate const&) noexcept = default; + SqlDate& operator=(SqlDate const&) noexcept = default; + ~SqlDate() noexcept = default; + + [[nodiscard]] std::chrono::year_month_day value() const noexcept + { + return ConvertToNative(sqlValue); + } + + bool operator==(SqlDate const& other) const noexcept + { + return sqlValue.year == other.sqlValue.year && sqlValue.month == other.sqlValue.month + && sqlValue.day == other.sqlValue.day; + } + + bool operator!=(SqlDate const& other) const noexcept + { + return !(*this == other); + } + + SqlDate(std::chrono::year_month_day value) noexcept: + sqlValue { SqlDate::ConvertToSqlValue(value) } + { + } + + SqlDate(std::chrono::year year, std::chrono::month month, std::chrono::day day) noexcept: + SqlDate(std::chrono::year_month_day { year, month, day }) + { + } + + static SqlDate Today() noexcept + { + return SqlDate { std::chrono::year_month_day { + std::chrono::floor(std::chrono::system_clock::now()), + } }; + } + + static SQL_DATE_STRUCT ConvertToSqlValue(std::chrono::year_month_day value) noexcept + { + return SQL_DATE_STRUCT { + .year = (SQLSMALLINT) (int) value.year(), + .month = (SQLUSMALLINT) (unsigned) value.month(), + .day = (SQLUSMALLINT) (unsigned) value.day(), + }; + } + + static std::chrono::year_month_day ConvertToNative(SQL_DATE_STRUCT const& value) noexcept + { + return std::chrono::year_month_day { std::chrono::year { value.year }, + std::chrono::month { static_cast(value.month) }, + std::chrono::day { static_cast(value.day) } }; + } +}; + +// Helper struct to store a time (of the day) to write to or read from a database. +struct SqlTime +{ + using native_type = std::chrono::hh_mm_ss; + +#if defined(SQL_SS_TIME2) + using sql_type = SQL_SS_TIME2_STRUCT; +#else + using sql_type = SQL_TIME_STRUCT; +#endif + + sql_type sqlValue {}; + + SqlTime() noexcept = default; + SqlTime(SqlTime&&) noexcept = default; + SqlTime& operator=(SqlTime&&) noexcept = default; + SqlTime(SqlTime const&) noexcept = default; + SqlTime& operator=(SqlTime const&) noexcept = default; + ~SqlTime() noexcept = default; + + [[nodiscard]] native_type value() const noexcept + { + return ConvertToNative(sqlValue); + } + + bool operator==(SqlTime const& other) const noexcept + { + return value().to_duration().count() == other.value().to_duration().count(); + } + + bool operator!=(SqlTime const& other) const noexcept + { + return !(*this == other); + } + + SqlTime(native_type value) noexcept: + sqlValue { SqlTime::ConvertToSqlValue(value) } + { + } + + SqlTime(std::chrono::hours hour, + std::chrono::minutes minute, + std::chrono::seconds second, + std::chrono::microseconds micros = {}) noexcept: + SqlTime(native_type { hour + minute + second + micros }) + { + } + + static sql_type ConvertToSqlValue(native_type value) noexcept + { + return sql_type { + .hour = (SQLUSMALLINT) value.hours().count(), + .minute = (SQLUSMALLINT) value.minutes().count(), + .second = (SQLUSMALLINT) value.seconds().count(), +#if defined(SQL_SS_TIME2) + .fraction = (SQLUINTEGER) value.subseconds().count(), +#endif + }; + } + + static native_type ConvertToNative(sql_type const& value) noexcept + { + // clang-format off + return native_type { std::chrono::hours { (int) value.hour } + + std::chrono::minutes { (unsigned) value.minute } + + std::chrono::seconds { (unsigned) value.second } +#if defined(SQL_SS_TIME2) + + std::chrono::microseconds { value.fraction } +#endif + + }; + // clang-format on + } +}; + +struct SqlDateTime +{ + using native_type = std::chrono::time_point; + + static SqlDateTime Now() noexcept + { + return SqlDateTime { std::chrono::system_clock::now() }; + } + + SqlDateTime() noexcept = default; + SqlDateTime(SqlDateTime&&) noexcept = default; + SqlDateTime& operator=(SqlDateTime&&) noexcept = default; + SqlDateTime(SqlDateTime const&) noexcept = default; + SqlDateTime& operator=(SqlDateTime const& other) noexcept = default; + ~SqlDateTime() noexcept = default; + + bool operator==(SqlDateTime const& other) const noexcept + { + return value() == other.value(); + } + + bool operator!=(SqlDateTime const& other) const noexcept + { + return !(*this == other); + } + + SqlDateTime(std::chrono::year_month_day ymd, std::chrono::hh_mm_ss time) noexcept + { + sqlValue.year = (SQLSMALLINT) (int) ymd.year(); + sqlValue.month = (SQLUSMALLINT) (unsigned) ymd.month(); + sqlValue.day = (SQLUSMALLINT) (unsigned) ymd.day(); + sqlValue.hour = (SQLUSMALLINT) time.hours().count(); + sqlValue.minute = (SQLUSMALLINT) time.minutes().count(); + sqlValue.second = (SQLUSMALLINT) time.seconds().count(); + sqlValue.fraction = (SQLUINTEGER) (time.subseconds().count() / 100) * 100; + } + + SqlDateTime(std::chrono::year year, + std::chrono::month month, + std::chrono::day day, + std::chrono::hours hour, + std::chrono::minutes minute, + std::chrono::seconds second, + std::chrono::nanoseconds nanosecond = std::chrono::nanoseconds { 0 }) noexcept + { + sqlValue.year = (SQLSMALLINT) (int) year; + sqlValue.month = (SQLUSMALLINT) (unsigned) month; + sqlValue.day = (SQLUSMALLINT) (unsigned) day; + sqlValue.hour = (SQLUSMALLINT) hour.count(); + sqlValue.minute = (SQLUSMALLINT) minute.count(); + sqlValue.second = (SQLUSMALLINT) second.count(); + sqlValue.fraction = (SQLUINTEGER) (nanosecond.count() / 100) * 100; + } + + SqlDateTime(std::chrono::system_clock::time_point value) noexcept: + sqlValue { SqlDateTime::ConvertToSqlValue(value) } + { + } + + operator native_type() const noexcept + { + return value(); + } + + static SQL_TIMESTAMP_STRUCT ConvertToSqlValue(native_type value) noexcept + { + using namespace std::chrono; + auto const totalDays = floor(value); + auto const ymd = year_month_day { totalDays }; + auto const hms = hh_mm_ss { floor(value - totalDays) }; + + return SQL_TIMESTAMP_STRUCT { + .year = (SQLSMALLINT) (int) ymd.year(), + .month = (SQLUSMALLINT) (unsigned) ymd.month(), + .day = (SQLUSMALLINT) (unsigned) ymd.day(), + .hour = (SQLUSMALLINT) hms.hours().count(), + .minute = (SQLUSMALLINT) hms.minutes().count(), + .second = (SQLUSMALLINT) hms.seconds().count(), + .fraction = (SQLUINTEGER) (hms.subseconds().count() / 100) * 100, + }; + } + + static native_type ConvertToNative(SQL_TIMESTAMP_STRUCT const& time) noexcept + { + // clang-format off + using namespace std::chrono; + auto timepoint = sys_days(year_month_day(year(time.year), month(time.month), day(time.day))) + + hours(time.hour) + + minutes(time.minute) + + seconds(time.second) + + nanoseconds(time.fraction); + return timepoint; + // clang-format on + } + + [[nodiscard]] native_type value() const noexcept + { + return ConvertToNative(sqlValue); + } + + SQL_TIMESTAMP_STRUCT sqlValue {}; +}; + +// Helper struct to store a timestamp that should be automatically converted to/from a SQL_TIMESTAMP_STRUCT. +// Helper struct to generically store and load a variant of different SQL types. +using SqlVariant = std::variant; + +// Callback interface for SqlDataBinder to allow post-processing of output columns. +// +// This is needed because the SQLBindCol() function does not allow to specify a callback function to be called +// after the data has been fetched from the database. This is needed to trim strings to the correct size, for +// example. +class SqlDataBinderCallback +{ + public: + virtual ~SqlDataBinderCallback() = default; + + virtual void PlanPostExecuteCallback(std::function&&) = 0; + virtual void PlanPostProcessOutputColumn(std::function&&) = 0; +}; + +template +struct SqlDataBinder +{ + static_assert(false, "No SQL data binder available for this type."); +}; + +// clang-format off +template +struct SqlSimpleDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, T const& value) noexcept + { + return SQLBindParameter(stmt, column, SQL_PARAM_INPUT, TheCType, TheSqlType, 0, 0, (SQLPOINTER) &value, 0, nullptr); + } + + static SQLRETURN OutputColumn(SQLHSTMT stmt, SQLUSMALLINT column, T* result, SQLLEN* indicator, SqlDataBinderCallback&) noexcept + { + return SQLBindCol(stmt, column, TheCType, result, 0, indicator); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, T* result, SQLLEN* indicator) noexcept + { + return SQLGetData(stmt, column, TheCType, result, 0, indicator); + } +}; + +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +#if !defined(_WIN32) && !defined(__APPLE__) +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +#endif +#if defined(__APPLE__) // size_t is a different type on macOS +template <> struct SqlDataBinder: SqlSimpleDataBinder {}; +#endif +// clang-format on + +// Default traits for output string parameters +// This needs to be implemented for each string type that should be used as output parameter via +// SqlDataBinder<>. An std::string specialization is provided below. Feel free to add more specializations for +// other string types, such as CString, etc. +template +struct SqlOutputStringTraits; + +// Specialized traits for std::string as output string parameter +template <> +struct SqlOutputStringTraits +{ + static char const* Data(std::string const* str) noexcept + { + return str->data(); + } + static char* Data(std::string* str) noexcept + { + return str->data(); + } + static SQLULEN Size(std::string const* str) noexcept + { + return str->size(); + } + static void Clear(std::string* str) noexcept + { + str->clear(); + } + + static void Reserve(std::string* str, size_t capacity) noexcept + { + // std::string tries to defer the allocation as long as possible. + // So we first tell std::string how much to reserve and then resize it to the *actually* reserved + // size. + str->reserve(capacity); + str->resize(str->capacity()); + } + + static void Resize(std::string* str, SQLLEN indicator) noexcept + { + if (indicator > 0) + str->resize(indicator); + } +}; + +template <> +struct SqlOutputStringTraits +{ + using Traits = SqlOutputStringTraits; + + // clang-format off + static char const* Data(SqlText const* str) noexcept { return Traits::Data(&str->value); } + static char* Data(SqlText* str) noexcept { return Traits::Data(&str->value); } + static SQLULEN Size(SqlText const* str) noexcept { return Traits::Size(&str->value); } + static void Clear(SqlText* str) noexcept { Traits::Clear(&str->value); } + static void Reserve(SqlText* str, size_t capacity) noexcept { Traits::Reserve(&str->value, capacity); } + static void Resize(SqlText* str, SQLLEN indicator) noexcept { Traits::Resize(&str->value, indicator); } + // clang-format on +}; + +template +struct SqlOutputStringTraits> +{ + using ValueType = SqlFixedString; + // clang-format off + static char const* Data(ValueType const* str) noexcept { return str->data(); } + static char* Data(ValueType* str) noexcept { return str->data(); } + static SQLULEN Size(ValueType const* str) noexcept { return str->size(); } + static void Clear(ValueType* str) noexcept { str->clear(); } + static void Reserve(ValueType* str, size_t capacity) noexcept { str->reserve(capacity); } + static void Resize(ValueType* str, SQLLEN indicator) noexcept { str->resize(indicator); } + // clang-format on +}; + +// clang-format off +template +concept SqlOutputStringTraitsConcept = requires(StringType* str) { + { SqlOutputStringTraits::Data(str) } -> std::same_as; + { SqlOutputStringTraits::Size(str) } -> std::same_as; + { SqlOutputStringTraits::Reserve(str, size_t {}) } -> std::same_as; + { SqlOutputStringTraits::Resize(str, SQLLEN {}) } -> std::same_as; + { SqlOutputStringTraits::Clear(str) } -> std::same_as; +}; +// clang-format on + +template +struct SqlDataBinder +{ + using ValueType = StringType; + using StringTraits = SqlOutputStringTraits; + + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, ValueType const& value) noexcept + { + return SQLBindParameter(stmt, + column, + SQL_PARAM_INPUT, + SQL_C_CHAR, + SQL_VARCHAR, + StringTraits::Size(&value), + 0, + (SQLPOINTER) StringTraits::Data(&value), + 0, + nullptr); + } + + static SQLRETURN OutputColumn( + SQLHSTMT stmt, SQLUSMALLINT column, ValueType* result, SQLLEN* indicator, SqlDataBinderCallback& cb) noexcept + { + // Ensure we're having sufficient space to store the worst-case scenario of bytes in this column + SQLULEN columnSize {}; + auto const describeResult = SQLDescribeCol(stmt, + column, + nullptr /*colName*/, + 0 /*sizeof(colName)*/, + nullptr /*&colNameLen*/, + nullptr /*&dataType*/, + &columnSize, + nullptr /*&decimalDigits*/, + nullptr /*&nullable*/); + if (!SQL_SUCCEEDED(describeResult)) + return describeResult; + + StringTraits::Reserve(result, + columnSize); // Must be called now, because otherwise std::string won't do anything + + cb.PlanPostProcessOutputColumn([indicator, result]() { + // Now resize the string to the actual length of the data + // NB: If the indicator is greater than the buffer size, we have a truncation. + auto const bufferSize = StringTraits::Size(result); + auto const len = std::cmp_greater_equal(*indicator, bufferSize) || *indicator == SQL_NO_TOTAL + ? bufferSize - 1 + : *indicator; + StringTraits::Resize(result, len); + }); + return SQLBindCol(stmt, + column, + SQL_C_CHAR, + (SQLPOINTER) StringTraits::Data(result), + (SQLLEN) StringTraits::Size(result), + indicator); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, ValueType* result, SQLLEN* indicator) noexcept + { + StringTraits::Reserve(result, 15); + size_t writeIndex = 0; + *indicator = 0; + while (true) + { + char* const bufferStart = StringTraits::Data(result) + writeIndex; + size_t const bufferSize = StringTraits::Size(result) - writeIndex; + SQLRETURN const rv = SQLGetData(stmt, column, SQL_C_CHAR, bufferStart, bufferSize, indicator); + switch (rv) + { + case SQL_SUCCESS: + case SQL_NO_DATA: + // last successive call + StringTraits::Resize(result, writeIndex + *indicator); + *indicator = StringTraits::Size(result); + return SQL_SUCCESS; + case SQL_SUCCESS_WITH_INFO: { + // more data pending + if (*indicator == SQL_NO_TOTAL) + { + // We have a truncation and the server does not know how much data is left. + writeIndex += bufferSize - 1; + StringTraits::Resize(result, (2 * writeIndex) + 1); + } + else if (std::cmp_greater_equal(*indicator, bufferSize)) + { + // We have a truncation and the server knows how much data is left. + writeIndex += bufferSize - 1; + StringTraits::Resize(result, writeIndex + *indicator); + } + else + { + // We have no truncation and the server knows how much data is left. + StringTraits::Resize(result, writeIndex + *indicator - 1); + return SQL_SUCCESS; + } + break; + } + default: + return rv; + } + } + } +}; + +template +struct SqlDataBinder> +{ + using ValueType = SqlFixedString; + using StringTraits = SqlOutputStringTraits; + + static void TrimRight(ValueType* boundOutputString, SQLLEN indicator) noexcept + { + size_t n = indicator; + while (n > 0 && std::isspace((*boundOutputString)[n - 1])) + --n; + boundOutputString->setsize(n); + } + + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, ValueType const& value) noexcept + { + return SQLBindParameter(stmt, + column, + SQL_PARAM_INPUT, + SQL_C_CHAR, + SQL_VARCHAR, + value.size(), + 0, + (SQLPOINTER) value.c_str(), // Ensure Null-termination. + sizeof(value), + nullptr); + } + + static SQLRETURN OutputColumn( + SQLHSTMT stmt, SQLUSMALLINT column, ValueType* result, SQLLEN* indicator, SqlDataBinderCallback& cb) noexcept + { + if constexpr (PostOp == SqlStringPostRetrieveOperation::TRIM_RIGHT) + { + ValueType* boundOutputString = result; + cb.PlanPostProcessOutputColumn([indicator, boundOutputString]() { + // NB: If the indicator is greater than the buffer size, we have a truncation. + auto const len = + std::cmp_greater_equal(*indicator, N + 1) || *indicator == SQL_NO_TOTAL ? N : *indicator; + if constexpr (PostOp == SqlStringPostRetrieveOperation::TRIM_RIGHT) + TrimRight(boundOutputString, len); + else + boundOutputString->setsize(len); + }); + } + return SQLBindCol( + stmt, column, SQL_C_CHAR, (SQLPOINTER) result->data(), (SQLLEN) result->capacity(), indicator); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, ValueType* result, SQLLEN* indicator) noexcept + { + *indicator = 0; + const SQLRETURN rv = SQLGetData(stmt, column, SQL_C_CHAR, result->data(), result->capacity(), indicator); + switch (rv) + { + case SQL_SUCCESS: + case SQL_NO_DATA: + // last successive call + result->setsize(*indicator); + if constexpr (PostOp == SqlStringPostRetrieveOperation::TRIM_RIGHT) + TrimRight(result, *indicator); + return SQL_SUCCESS; + case SQL_SUCCESS_WITH_INFO: { + // more data pending + // Truncating. This case should never happen. + result->setsize(result->capacity() - 1); + if constexpr (PostOp == SqlStringPostRetrieveOperation::TRIM_RIGHT) + TrimRight(result, *indicator); + return SQL_SUCCESS; + } + default: + return rv; + } + } +}; + +template +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, char const* value) noexcept + { + static_assert(N > 0, "N must be greater than 0"); // I cannot imagine that N is 0, ever. + return SQLBindParameter( + stmt, column, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, N - 1, 0, (SQLPOINTER) value, 0, nullptr); + } +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, std::string_view value) noexcept + { + return SQLBindParameter(stmt, + column, + SQL_PARAM_INPUT, + SQL_C_CHAR, + SQL_VARCHAR, + value.size(), + 0, + (SQLPOINTER) value.data(), + 0, + nullptr); + } +}; + +template <> +struct SqlDataBinder +{ + using InnerStringType = decltype(std::declval().value); + using StringTraits = SqlOutputStringTraits; + + static void TrimRight(InnerStringType* boundOutputString, SQLLEN indicator) noexcept + { + size_t n = indicator; + while (n > 0 && std::isspace((*boundOutputString)[n - 1])) + --n; + StringTraits::Resize(boundOutputString, static_cast(n)); + } + + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, SqlTrimmedString const& value) noexcept + { + return SqlDataBinder::InputParameter(stmt, column, value.value); + } + + static SQLRETURN OutputColumn(SQLHSTMT stmt, + SQLUSMALLINT column, + SqlTrimmedString* result, + SQLLEN* indicator, + SqlDataBinderCallback& cb) noexcept + { + auto* boundOutputString = &result->value; + cb.PlanPostProcessOutputColumn([indicator, boundOutputString]() { + // NB: If the indicator is greater than the buffer size, we have a truncation. + auto const bufferSize = StringTraits::Size(boundOutputString); + auto const len = std::cmp_greater_equal(*indicator, bufferSize) || *indicator == SQL_NO_TOTAL + ? bufferSize - 1 + : *indicator; + TrimRight(boundOutputString, static_cast(len)); + }); + return SQLBindCol(stmt, + column, + SQL_C_CHAR, + (SQLPOINTER) StringTraits::Data(boundOutputString), + (SQLLEN) StringTraits::Size(boundOutputString), + indicator); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, SqlTrimmedString* result, SQLLEN* indicator) noexcept + { + auto const returnCode = SqlDataBinder::GetColumn(stmt, column, &result->value, indicator); + TrimRight(&result->value, *indicator); + return returnCode; + } +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, SqlDate const& value) noexcept + { + return SQLBindParameter(stmt, + column, + SQL_PARAM_INPUT, + SQL_C_TYPE_DATE, + SQL_TYPE_DATE, + 0, + 0, + (SQLPOINTER) &value.sqlValue, + 0, + nullptr); + } + + static SQLRETURN OutputColumn(SQLHSTMT stmt, + SQLUSMALLINT column, + SqlDate* result, + SQLLEN* /*indicator*/, + SqlDataBinderCallback& /*cb*/) noexcept + { + // TODO: handle indicator to check for NULL values + return SQLBindCol(stmt, column, SQL_C_TYPE_DATE, &result->sqlValue, sizeof(result->sqlValue), nullptr); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, SqlDate* result, SQLLEN* indicator) noexcept + { + return SQLGetData(stmt, column, SQL_C_TYPE_DATE, &result->sqlValue, sizeof(result->sqlValue), indicator); + } +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, SqlTime const& value) noexcept + { + return SQLBindParameter(stmt, + column, + SQL_PARAM_INPUT, + SQL_C_TYPE_TIME, + SQL_TYPE_TIME, + 0, + 0, + (SQLPOINTER) &value.sqlValue, + 0, + nullptr); + } + + static SQLRETURN OutputColumn(SQLHSTMT stmt, + SQLUSMALLINT column, + SqlTime* result, + SQLLEN* /*indicator*/, + SqlDataBinderCallback& /*cb*/) noexcept + { + // TODO: handle indicator to check for NULL values + return SQLBindCol(stmt, column, SQL_C_TYPE_TIME, &result->sqlValue, sizeof(result->sqlValue), nullptr); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, SqlTime* result, SQLLEN* indicator) noexcept + { + return SQLGetData(stmt, column, SQL_C_TYPE_TIME, &result->sqlValue, sizeof(result->sqlValue), indicator); + } +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN GetColumn(SQLHSTMT stmt, + SQLUSMALLINT column, + SqlDateTime::native_type* result, + SQLLEN* indicator) noexcept + { + SQL_TIMESTAMP_STRUCT sqlValue {}; + auto const rc = SQLGetData(stmt, column, SQL_C_TYPE_TIMESTAMP, &sqlValue, sizeof(sqlValue), indicator); + if (SQL_SUCCEEDED(rc)) + *result = SqlDateTime::ConvertToNative(sqlValue); + return rc; + } +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT stmt, SQLUSMALLINT column, SqlDateTime const& value) noexcept + { + return SQLBindParameter(stmt, + column, + SQL_PARAM_INPUT, + SQL_C_TIMESTAMP, + SQL_TYPE_TIMESTAMP, + 27, + 7, + (SQLPOINTER) &value.sqlValue, + sizeof(value), + nullptr); + } + + static SQLRETURN OutputColumn(SQLHSTMT stmt, + SQLUSMALLINT column, + SqlDateTime* result, + SQLLEN* indicator, + SqlDataBinderCallback& /*cb*/) noexcept + { + // TODO: handle indicator to check for NULL values + *indicator = sizeof(result->sqlValue); + return SQLBindCol(stmt, column, SQL_C_TYPE_TIMESTAMP, &result->sqlValue, 0, indicator); + } + + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, SqlDateTime* result, SQLLEN* indicator) noexcept + { + return SQLGetData(stmt, column, SQL_C_TYPE_TIMESTAMP, &result->sqlValue, sizeof(result->sqlValue), indicator); + } +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN GetColumn(SQLHSTMT stmt, SQLUSMALLINT column, SqlVariant* result, SQLLEN* indicator) noexcept + { + SQLLEN columnType {}; + SQLRETURN returnCode = + SQLColAttributeA(stmt, static_cast(column), SQL_DESC_TYPE, nullptr, 0, nullptr, &columnType); + if (!SQL_SUCCEEDED(returnCode)) + return returnCode; + + switch (columnType) + { + case SQL_BIT: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_TINYINT: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_SMALLINT: + result->emplace(); + returnCode = SqlDataBinder::GetColumn( + stmt, column, &std::get(*result), indicator); + break; + case SQL_INTEGER: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_BIGINT: + result->emplace(); + returnCode = + SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_REAL: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_FLOAT: + case SQL_DOUBLE: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_CHAR: // fixed-length string + case SQL_VARCHAR: // variable-length string + case SQL_LONGVARCHAR: // long string + case SQL_WCHAR: // fixed-length Unicode (UTF-16) string + case SQL_WVARCHAR: // variable-length Unicode (UTF-16) string + case SQL_WLONGVARCHAR: // long Unicode (UTF-16) string + case SQL_BINARY: // fixed-length binary + case SQL_VARBINARY: // variable-length binary + case SQL_LONGVARBINARY: // long binary + result->emplace(); + returnCode = + SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_DATE: + SqlLogger::GetLogger().OnWarning( + std::format("SQL_DATE is from ODBC 2. SQL_TYPE_DATE should have been received instead.")); + [[fallthrough]]; + case SQL_TYPE_DATE: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_TIME: + SqlLogger::GetLogger().OnWarning( + std::format("SQL_TIME is from ODBC 2. SQL_TYPE_TIME should have been received instead.")); + [[fallthrough]]; + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + result->emplace(); + returnCode = SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_TYPE_TIMESTAMP: + result->emplace(); + returnCode = + SqlDataBinder::GetColumn(stmt, column, &std::get(*result), indicator); + break; + case SQL_TYPE_NULL: + case SQL_DECIMAL: + case SQL_NUMERIC: + case SQL_GUID: + // TODO: Get them implemented on demand + [[fallthrough]]; + default: + std::println("Unsupported column type: {}", columnType); + SqlLogger::GetLogger().OnError(SqlError::UNSUPPORTED_TYPE, SqlErrorInfo::fromStatementHandle(stmt)); + returnCode = SQL_ERROR; // std::errc::invalid_argument; + } + if (*indicator == SQL_NULL_DATA) + *result = std::monostate {}; + return returnCode; + } +}; + +template +concept SqlInputParameterBinder = requires(SQLHSTMT hStmt, SQLUSMALLINT column, T const& value) { + { SqlDataBinder::InputParameter(hStmt, column, value) } -> std::same_as; +}; + +template +concept SqlOutputColumnBinder = + requires(SQLHSTMT hStmt, SQLUSMALLINT column, T* result, SQLLEN* indicator, SqlDataBinderCallback& cb) { + { SqlDataBinder::OutputColumn(hStmt, column, result, indicator, cb) } -> std::same_as; + }; + +template +concept SqlInputParameterBatchBinder = + requires(SQLHSTMT hStmt, SQLUSMALLINT column, std::ranges::range_value_t* result) { + { + SqlDataBinder>::InputParameter( + hStmt, column, std::declval>()) + } -> std::same_as; + }; + +template +concept SqlGetColumnNativeType = requires(SQLHSTMT hStmt, SQLUSMALLINT column, T* result, SQLLEN* indicator) { + { SqlDataBinder::GetColumn(hStmt, column, result, indicator) } -> std::same_as; +}; diff --git a/src/Lightweight/SqlError.cpp b/src/Lightweight/SqlError.cpp new file mode 100644 index 00000000..5de26f05 --- /dev/null +++ b/src/Lightweight/SqlError.cpp @@ -0,0 +1,9 @@ +#include "SqlError.hpp" + +void SqlErrorInfo::RequireStatementSuccess(SQLRETURN result, SQLHSTMT hStmt, std::string_view message) +{ + if (result == SQL_SUCCESS || result == SQL_SUCCESS_WITH_INFO) [[likely]] + return; + + throw std::runtime_error { std::format("{}: {}", message, fromStatementHandle(hStmt)) }; +} diff --git a/src/Lightweight/SqlError.hpp b/src/Lightweight/SqlError.hpp new file mode 100644 index 00000000..6380111e --- /dev/null +++ b/src/Lightweight/SqlError.hpp @@ -0,0 +1,178 @@ +#pragma once + +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include +#include +#include + +#include +#include +#include +#include + +// NOTE: This is a simple wrapper around the SQL return codes. It is not meant to be +// comprehensive, but rather to provide a simple way to convert SQL return codes to +// std::error_code. +// +// The code below is DRAFT and may be subject to change. + +struct SqlErrorInfo +{ + SQLINTEGER nativeErrorCode {}; + std::string sqlState = " "; // 5 characters + null terminator + std::string message; + + static SqlErrorInfo fromConnectionHandle(SQLHDBC hDbc) + { + return fromHandle(SQL_HANDLE_DBC, hDbc); + } + + static SqlErrorInfo fromStatementHandle(SQLHSTMT hStmt) + { + return fromHandle(SQL_HANDLE_STMT, hStmt); + } + + static SqlErrorInfo fromHandle(SQLSMALLINT handleType, SQLHANDLE handle) + { + SqlErrorInfo info {}; + info.message = std::string(1024, '\0'); + + SQLSMALLINT msgLen {}; + SQLGetDiagRecA(handleType, + handle, + 1, + (SQLCHAR*) info.sqlState.data(), + &info.nativeErrorCode, + (SQLCHAR*) info.message.data(), + (SQLSMALLINT) info.message.capacity(), + &msgLen); + info.message.resize(msgLen); + return info; + } + + static void RequireStatementSuccess(SQLRETURN result, SQLHSTMT hStmt, std::string_view message); +}; + +enum class SqlError : std::int16_t +{ + SUCCESS = SQL_SUCCESS, + SUCCESS_WITH_INFO = SQL_SUCCESS_WITH_INFO, + NODATA = SQL_NO_DATA, + FAILURE = SQL_ERROR, + INVALID_HANDLE = SQL_INVALID_HANDLE, + STILL_EXECUTING = SQL_STILL_EXECUTING, + NEED_DATA = SQL_NEED_DATA, + PARAM_DATA_AVAILABLE = SQL_PARAM_DATA_AVAILABLE, + NO_DATA_FOUND = SQL_NO_DATA_FOUND, + UNSUPPORTED_TYPE = 1'000, + INVALID_ARGUMENT = 1'001, +}; + +struct SqlErrorCategory: std::error_category +{ + static SqlErrorCategory const& get() noexcept + { + static SqlErrorCategory const category; + return category; + } + [[nodiscard]] const char* name() const noexcept override + { + return "Lightweight"; + } + + [[nodiscard]] std::string message(int code) const override + { + using namespace std::string_literals; + switch (static_cast(code)) + { + case SqlError::SUCCESS: + return "SQL_SUCCESS"s; + case SqlError::SUCCESS_WITH_INFO: + return "SQL_SUCCESS_WITH_INFO"s; + case SqlError::NODATA: + return "SQL_NO_DATA"s; + case SqlError::FAILURE: + return "SQL_ERROR"s; + case SqlError::INVALID_HANDLE: + return "SQL_INVALID_HANDLE"s; + case SqlError::STILL_EXECUTING: + return "SQL_STILL_EXECUTING"s; + case SqlError::NEED_DATA: + return "SQL_NEED_DATA"s; + case SqlError::PARAM_DATA_AVAILABLE: + return "SQL_PARAM_DATA_AVAILABLE"s; + case SqlError::UNSUPPORTED_TYPE: + return "SQL_UNSUPPORTED_TYPE"s; + case SqlError::INVALID_ARGUMENT: + return "SQL_INVALID_ARGUMENT"s; + } + return std::format("SQL error code {}", code); + } +}; + +// Register our enum as an error code so we can constructor error_code from it +template <> +struct std::is_error_code_enum: public std::true_type +{ +}; + +// Tells the compiler that MyErr pairs with MyCategory +inline std::error_code make_error_code(SqlError e) +{ + return { static_cast(e), SqlErrorCategory::get() }; +} + +// Represents the result of a call to the SQL server. +template +using SqlResult = std::expected; + +namespace detail +{ + +inline SqlResult UpdateSqlError(SqlError* errorCode, SQLRETURN error) noexcept +{ + if (error == SQL_SUCCESS || error == SQL_SUCCESS_WITH_INFO) + { + *errorCode = SqlError::SUCCESS; + return {}; + } + + *errorCode = static_cast(error); + return std::unexpected { *errorCode }; +} + +} // namespace detail + +template <> +struct std::formatter: formatter +{ + auto format(SqlError value, format_context& ctx) const -> format_context::iterator + { + return formatter::format(std::format("{}", SqlErrorCategory().message(static_cast(value))), + ctx); + } +}; + +template <> +struct std::formatter: formatter +{ + auto format(SqlErrorInfo const& info, format_context& ctx) const -> format_context::iterator + { + return formatter::format( + std::format("{} ({}) - {}", info.sqlState, info.nativeErrorCode, info.message), ctx); + } +}; + +template +struct std::formatter>: std::formatter +{ + auto format(SqlResult const& result, format_context& ctx) -> format_context::iterator + { + if (result) + return std::formatter::format(std::format("{}", result.value()), ctx); + return std::formatter::format(std::format("{}", result.error()), ctx); + } +}; diff --git a/src/Lightweight/SqlLogger.cpp b/src/Lightweight/SqlLogger.cpp new file mode 100644 index 00000000..6874d18f --- /dev/null +++ b/src/Lightweight/SqlLogger.cpp @@ -0,0 +1,185 @@ +#include "SqlConnectInfo.hpp" +#include "SqlConnection.hpp" +#include "SqlLogger.hpp" + +#include +#include +#include +#include +#include +#include + +#if __has_include() + #include +#endif + +#if defined(_MSC_VER) + // Disable warning C4996: This function or variable may be unsafe. + // It is complaining about getenv, which is fine to use in this case. + #pragma warning(disable : 4996) +#endif + +namespace +{ + +class SqlStandardLogger: public SqlLogger +{ + private: + std::chrono::time_point m_currentTime {}; + std::string m_currentTimeStr {}; + + public: + void Tick() + { + m_currentTime = std::chrono::system_clock::now(); + auto const nowMs = time_point_cast(m_currentTime); + m_currentTimeStr = std::format("{:%F %X}.{:03}", m_currentTime, nowMs.time_since_epoch().count() % 1'000); + } + + template + void WriteMessage(std::format_string const& fmt, Args&&... args) + { + // TODO: Use the new logging mechanism from Felix here, once merged. + std::println("[{}] {}", m_currentTimeStr, std::format(fmt, std::forward(args)...)); + } + + void OnWarning(std::string_view const& message) override + { + Tick(); + WriteMessage("Warning: {}", message); + } + + void OnError(SqlError errorCode, SqlErrorInfo const& errorInfo, std::source_location /*sourceLocation*/) override + { + Tick(); + WriteMessage("Error: {}", SqlErrorCategory().message((int) errorCode)); + WriteMessage(" SQLSTATE: {}", errorInfo.sqlState); + WriteMessage(" Native error code: {}", errorInfo.nativeErrorCode); + WriteMessage(" Message: {}", errorInfo.message); + } + + void OnConnectionOpened(SqlConnection const&) override {} + void OnConnectionClosed(SqlConnection const&) override {} + void OnConnectionIdle(SqlConnection const&) override {} + void OnConnectionReuse(SqlConnection const&) override {} + void OnExecuteDirect(std::string_view const&) override {} + void OnPrepare(std::string_view const&) override {} + void OnExecute() override {} + void OnExecuteBatch() override {} + void OnFetchedRow() override {} + void OnStats(SqlConnectionStats const&) override {} +}; + +class SqlTraceLogger: public SqlStandardLogger +{ + std::string m_lastPreparedQuery; + + public: + void OnError(SqlError errorCode, SqlErrorInfo const& errorInfo, std::source_location sourceLocation) override + { + SqlStandardLogger::OnError(errorCode, errorInfo, sourceLocation); + + WriteMessage(" Source: {}:{}", sourceLocation.file_name(), sourceLocation.line()); + if (!m_lastPreparedQuery.empty()) + WriteMessage(" Query: {}", m_lastPreparedQuery); + WriteMessage(" Stack trace:"); + +#if __has_include() + auto stackTrace = std::stacktrace::current(1, 25); + for (std::size_t const i: std::views::iota(std::size_t(0), stackTrace.size())) + WriteMessage(" [{:>2}] {}", i, stackTrace[i]); +#endif + } + + void OnConnectionOpened(SqlConnection const& connection) override + { + Tick(); + WriteMessage("Connection {} opened: {}", connection.ConnectionId(), connection.ConnectionInfo()); + } + + void OnConnectionClosed(SqlConnection const& connection) override + { + Tick(); + WriteMessage("Connection {} closed: {}", connection.ConnectionId(), connection.ConnectionInfo()); + } + + void OnConnectionIdle(SqlConnection const& /*connection*/) override + { + // Tick(); + // WriteMessage("Connection {} idle: {}", connection.ConnectionId(), connection.ConnectionInfo()); + } + + void OnConnectionReuse(SqlConnection const& /*connection*/) override + { + // Tick(); + // WriteMessage("Connection {} reused: {}", connection.ConnectionId(), connection.ConnectionInfo()); + } + + void OnExecuteDirect(std::string_view const& query) override + { + Tick(); + WriteMessage("ExecuteDirect: {}", query); + } + + void OnPrepare(std::string_view const& query) override + { + m_lastPreparedQuery = query; + } + + void OnExecute() override + { + Tick(); + WriteMessage("Execute: {}", m_lastPreparedQuery); + m_lastPreparedQuery.clear(); + } + + void OnExecuteBatch() override + { + Tick(); + WriteMessage("ExecuteBatch: {}", m_lastPreparedQuery); + } + + void OnFetchedRow() override + { + Tick(); + WriteMessage("Fetched row"); + } + + void OnStats(SqlConnectionStats const& stats) override + { + Tick(); + WriteMessage("[SqlConnectionPool stats] " + "created: {}, reused: {}, closed: {}, timedout: {}, released: {}", + stats.created, + stats.reused, + stats.closed, + stats.timedout, + stats.released); + } +}; + +} // namespace + +static SqlStandardLogger theStdLogger {}; +SqlLogger& SqlLogger::StandardLogger() +{ + return theStdLogger; +} + +SqlLogger& SqlLogger::TraceLogger() +{ + static SqlTraceLogger logger {}; + return logger; +} + +static SqlLogger* theDefaultLogger = &SqlLogger::StandardLogger(); + +SqlLogger& SqlLogger::GetLogger() +{ + return *theDefaultLogger; +} + +void SqlLogger::SetLogger(SqlLogger& logger) +{ + theDefaultLogger = &logger; +} diff --git a/src/Lightweight/SqlLogger.hpp b/src/Lightweight/SqlLogger.hpp new file mode 100644 index 00000000..4ebab617 --- /dev/null +++ b/src/Lightweight/SqlLogger.hpp @@ -0,0 +1,58 @@ +#pragma once + +#include "SqlConnectInfo.hpp" +#include "SqlError.hpp" + +#include +#include +#include + +class SqlConnection; + +struct SqlConnectionStats +{ + size_t created {}; + size_t reused {}; + size_t closed {}; + size_t timedout {}; + size_t released {}; +}; + +// Represents a logger for SQL operations. +class SqlLogger +{ + public: + virtual ~SqlLogger() = default; + + // Logs a warning message. + virtual void OnWarning(std::string_view const& message) = 0; + + // An ODBC SQL error occurred. + virtual void OnError(SqlError errorCode, + SqlErrorInfo const& errorInfo, + std::source_location sourceLocation = std::source_location::current()) = 0; + + virtual void OnConnectionOpened(SqlConnection const& connection) = 0; + virtual void OnConnectionClosed(SqlConnection const& connection) = 0; + virtual void OnConnectionIdle(SqlConnection const& connection) = 0; + virtual void OnConnectionReuse(SqlConnection const& connection) = 0; + virtual void OnExecuteDirect(std::string_view const& query) = 0; + virtual void OnPrepare(std::string_view const& query) = 0; + virtual void OnExecute() = 0; + virtual void OnExecuteBatch() = 0; + virtual void OnFetchedRow() = 0; + virtual void OnStats(SqlConnectionStats const& stats) = 0; + + // Logs the most important events to standard output in a human-readable format. + static SqlLogger& StandardLogger(); + + // Logs every little event to standard output in a human-readable compact format. + static SqlLogger& TraceLogger(); + + // Retrieves the current logger. + static SqlLogger& GetLogger(); + + // Sets the current logger. + // The ownership of the logger is not transferred and remains with the caller. + static void SetLogger(SqlLogger& logger); +}; diff --git a/src/Lightweight/SqlQueryFormatter.cpp b/src/Lightweight/SqlQueryFormatter.cpp new file mode 100644 index 00000000..eff79ed0 --- /dev/null +++ b/src/Lightweight/SqlQueryFormatter.cpp @@ -0,0 +1,188 @@ +#include "SqlComposedQuery.hpp" +#include "SqlQueryFormatter.hpp" + +#include +#include + +using namespace std::string_view_literals; + +namespace +{ + +class BasicSqlQueryFormatter: public SqlQueryFormatter +{ + public: + [[nodiscard]] std::string BooleanWhereClause(SqlQualifiedTableColumnName const& column, + std::string_view op, + bool literalValue) const override + { + auto const literalValueStr = literalValue ? "TRUE"sv : "FALSE"sv; + if (!column.tableName.empty()) + return std::format(R"("{}"."{}" {} {})", column.tableName, column.columnName, op, literalValueStr); + else + return std::format(R"("{}" {} {})", column.columnName, op, literalValueStr); + } + + [[nodiscard]] std::string SelectCount(std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition) const override + { + return std::format("SELECT COUNT(*) FROM \"{}\"{}{}", fromTable, tableJoins, whereCondition); + } + + [[nodiscard]] std::string SelectAll(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + std::string const& groupBy) const override + { + const auto* const delimiter = tableJoins.empty() ? "" : "\n "; + // clang-format off + std::stringstream sqlQueryString; + sqlQueryString << "SELECT " << fields + << delimiter << " FROM \"" << fromTable << "\"" + << tableJoins + << delimiter << whereCondition + << delimiter << groupBy + << delimiter << orderBy; + return sqlQueryString.str(); + // clang-format on + } + + [[nodiscard]] std::string SelectFirst(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + size_t count) const override + { + // clang-format off + std::stringstream sqlQueryString; + sqlQueryString << "SELECT " << fields + << " FROM \"" << fromTable << "\"" + << tableJoins + << whereCondition + << orderBy + << " LIMIT " << count; + return sqlQueryString.str(); + // clang-format on + } + + [[nodiscard]] std::string SelectRange(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + std::string const& groupBy, + std::size_t offset, + std::size_t limit) const override + { + // clang-format off + std::stringstream sqlQueryString; + sqlQueryString << "SELECT " << fields + << " FROM \"" << fromTable << "\"" + << tableJoins + << whereCondition + << groupBy + << orderBy + << " LIMIT " << limit << " OFFSET " << offset; + return sqlQueryString.str(); + // clang-format on + } +}; + +class SqlServerQueryFormatter final: public BasicSqlQueryFormatter +{ + public: + std::string BooleanWhereClause(SqlQualifiedTableColumnName const& column, + std::string_view op, + bool literalValue) const override + { + auto const literalValueStr = literalValue ? '1' : '0'; + if (!column.tableName.empty()) + return std::format(R"("{}"."{}" {} {})", column.columnName, column.columnName, op, literalValueStr); + else + return std::format(R"("{}" {} {})", column.columnName, op, literalValueStr); + } + + std::string SelectFirst(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + size_t count) const override + { + // clang-format off + std::stringstream sqlQueryString; + sqlQueryString << "SELECT TOP " << count << " " + << fields + << " FROM \"" << fromTable << "\"" + << tableJoins + << whereCondition + << orderBy; + return sqlQueryString.str(); + // clang-format on + } + + std::string SelectRange(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + std::string const& groupBy, + std::size_t offset, + std::size_t limit) const override + { + assert(!orderBy.empty()); + // clang-format off + std::stringstream sqlQueryString; + sqlQueryString << "SELECT " << fields + << " FROM \"" << fromTable << "\"" + << tableJoins + << whereCondition + << groupBy + << orderBy + << " OFFSET " << offset << " ROWS FETCH NEXT " << limit << " ROWS ONLY"; + return sqlQueryString.str(); + // clang-format on + } +}; + +} // namespace + +SqlQueryFormatter const& SqlQueryFormatter::Sqlite() +{ + static const BasicSqlQueryFormatter formatter {}; + return formatter; +} + +SqlQueryFormatter const& SqlQueryFormatter::SqlServer() +{ + static const SqlServerQueryFormatter formatter {}; + return formatter; +} + +SqlQueryFormatter const& SqlQueryFormatter::PostgrSQL() +{ + static const BasicSqlQueryFormatter formatter {}; + return formatter; +} + +SqlQueryFormatter const* SqlQueryFormatter::Get(SqlServerType serverType) noexcept +{ + switch (serverType) + { + case SqlServerType::SQLITE: + return &Sqlite(); + case SqlServerType::MICROSOFT_SQL: + return &SqlServer(); + case SqlServerType::POSTGRESQL: + return &PostgrSQL(); + case SqlServerType::ORACLE: // TODO + case SqlServerType::MYSQL: // TODO + case SqlServerType::UNKNOWN: + break; + } + return nullptr; +} diff --git a/src/Lightweight/SqlQueryFormatter.hpp b/src/Lightweight/SqlQueryFormatter.hpp new file mode 100644 index 00000000..12619aa2 --- /dev/null +++ b/src/Lightweight/SqlQueryFormatter.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include "SqlConnection.hpp" +#include "SqlDataBinder.hpp" + +#include +#include + +struct SqlQualifiedTableColumnName; + +// API to format SQL queries for different SQL dialects. +class SqlQueryFormatter +{ + public: + virtual ~SqlQueryFormatter() = default; + + [[nodiscard]] virtual std::string BooleanWhereClause(SqlQualifiedTableColumnName const& column, + std::string_view op, + bool literalValue) const = 0; + + [[nodiscard]] virtual std::string SelectAll(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + std::string const& groupBy) const = 0; + + [[nodiscard]] virtual std::string SelectFirst(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + size_t count) const = 0; + + [[nodiscard]] virtual std::string SelectRange(std::string const& fields, + std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition, + std::string const& orderBy, + std::string const& groupBy, + std::size_t offset, + std::size_t limit) const = 0; + + [[nodiscard]] virtual std::string SelectCount(std::string const& fromTable, + std::string const& tableJoins, + std::string const& whereCondition) const = 0; + + static SqlQueryFormatter const& Sqlite(); + static SqlQueryFormatter const& SqlServer(); + static SqlQueryFormatter const& PostgrSQL(); + + static SqlQueryFormatter const* Get(SqlServerType serverType) noexcept; +}; diff --git a/src/Lightweight/SqlSchema.cpp b/src/Lightweight/SqlSchema.cpp new file mode 100644 index 00000000..5142b04c --- /dev/null +++ b/src/Lightweight/SqlSchema.cpp @@ -0,0 +1,350 @@ +#include "SqlConnection.hpp" +#include "SqlSchema.hpp" +#include "SqlStatement.hpp" + +#include + +#include +#include +#include +#include + +namespace SqlSchema +{ + +using namespace std::string_literals; +using namespace std::string_view_literals; + +using KeyPair = std::pair; + +bool operator<(KeyPair const& a, KeyPair const& b) +{ + return std::tie(a.first, a.second) < std::tie(b.first, b.second); +} + +namespace +{ + SqlColumnType FromNativeDataType(int value) + { + switch (value) + { + case SQL_UNKNOWN_TYPE: + return SqlColumnType::UNKNOWN; + case SQL_CHAR: + case SQL_WCHAR: + return SqlColumnType::CHAR; + case SQL_VARCHAR: + case SQL_WVARCHAR: + return SqlColumnType::STRING; + case SQL_LONGVARCHAR: + case SQL_WLONGVARCHAR: + return SqlColumnType::TEXT; + case SQL_BIT: + return SqlColumnType::BOOLEAN; + case SQL_TINYINT: + return SqlColumnType::INTEGER; + case SQL_SMALLINT: + return SqlColumnType::INTEGER; + case SQL_INTEGER: + return SqlColumnType::INTEGER; + case SQL_BIGINT: + return SqlColumnType::INTEGER; + case SQL_REAL: + return SqlColumnType::REAL; + case SQL_FLOAT: + return SqlColumnType::REAL; + case SQL_DOUBLE: + return SqlColumnType::REAL; + case SQL_TYPE_DATE: + return SqlColumnType::DATE; + case SQL_TYPE_TIME: + return SqlColumnType::TIME; + case SQL_TYPE_TIMESTAMP: + return SqlColumnType::DATETIME; + default: + std::println("Unknown SQL type {}", value); + return SqlColumnType::UNKNOWN; + } + } + + std::vector AllTables(std::string_view database, std::string_view schema) + { + auto const tableType = "TABLE"sv; + (void) database; + (void) schema; + + auto stmt = SqlStatement(); + auto sqlResult = SQLTables(stmt.NativeHandle(), + (SQLCHAR*) database.data(), + (SQLSMALLINT) database.size(), + (SQLCHAR*) schema.data(), + (SQLSMALLINT) schema.size(), + nullptr /* tables */, + 0 /* tables length */, + (SQLCHAR*) tableType.data(), + (SQLSMALLINT) tableType.size()); + SqlErrorInfo::RequireStatementSuccess(sqlResult, stmt.NativeHandle(), "SQLTables"); + + auto result = std::vector(); + while (stmt.FetchRow()) + result.emplace_back(stmt.GetColumn(3)); + + return result; + } + + std::vector AllForeignKeys(FullyQualifiedTableName const& primaryKey, + FullyQualifiedTableName const& foreignKey) + { + auto stmt = SqlStatement(); + auto sqlResult = SQLForeignKeys(stmt.NativeHandle(), + (SQLCHAR*) primaryKey.catalog.data(), + (SQLSMALLINT) primaryKey.catalog.size(), + (SQLCHAR*) primaryKey.schema.data(), + (SQLSMALLINT) primaryKey.schema.size(), + (SQLCHAR*) primaryKey.table.data(), + (SQLSMALLINT) primaryKey.table.size(), + (SQLCHAR*) foreignKey.catalog.data(), + (SQLSMALLINT) foreignKey.catalog.size(), + (SQLCHAR*) foreignKey.schema.data(), + (SQLSMALLINT) foreignKey.schema.size(), + (SQLCHAR*) foreignKey.table.data(), + (SQLSMALLINT) foreignKey.table.size()); + + if (!SQL_SUCCEEDED(sqlResult)) + throw std::runtime_error( + std::format("SQLForeignKeys failed: {}", SqlErrorInfo::fromStatementHandle(stmt.NativeHandle()))); + + using ColumnList = std::vector; + auto constraints = std::map(); + while (stmt.FetchRow()) + { + auto primaryKeyTable = FullyQualifiedTableName { + .catalog = stmt.GetColumn(1), + .schema = stmt.GetColumn(2), + .table = stmt.GetColumn(3), + }; + auto foreignKeyTable = FullyQualifiedTableColumn { + .table = + FullyQualifiedTableName { + .catalog = stmt.GetColumn(5), + .schema = stmt.GetColumn(6), + .table = stmt.GetColumn(7), + }, + .column = stmt.GetColumn(8), + }; + ColumnList& keyColumns = constraints[{ primaryKeyTable, foreignKeyTable }]; + auto const sequenceNumber = stmt.GetColumn(9); + if (sequenceNumber > keyColumns.size()) + keyColumns.resize(sequenceNumber); + keyColumns[sequenceNumber - 1] = stmt.GetColumn(4); + } + + auto result = std::vector(); + for (auto const& [keyPair, columns]: constraints) + { + result.emplace_back(ForeignKeyConstraint { + .foreignKey = keyPair.second, + .primaryKey = { + .table = keyPair.first, + .columns = columns, + }, + }); + } + return result; + } + + std::vector AllForeignKeysTo(FullyQualifiedTableName const& table) + { + return AllForeignKeys(table, FullyQualifiedTableName {}); + } + + std::vector AllForeignKeysFrom(FullyQualifiedTableName const& table) + { + return AllForeignKeys(FullyQualifiedTableName {}, table); + } + + std::vector AllPrimaryKeys(FullyQualifiedTableName const& table) + { + std::vector keys; + std::vector sequenceNumbers; + + auto stmt = SqlStatement(); + + auto sqlResult = SQLPrimaryKeys(stmt.NativeHandle(), + (SQLCHAR*) table.catalog.data(), + (SQLSMALLINT) table.catalog.size(), + (SQLCHAR*) table.schema.data(), + (SQLSMALLINT) table.schema.size(), + (SQLCHAR*) table.table.data(), + (SQLSMALLINT) table.table.size()); + if (!SQL_SUCCEEDED(sqlResult)) + throw std::runtime_error( + std::format("SQLPrimaryKeys failed: {}", SqlErrorInfo::fromStatementHandle(stmt.NativeHandle()))); + + while (stmt.FetchRow()) + { + keys.emplace_back(stmt.GetColumn(4)); + sequenceNumbers.emplace_back(stmt.GetColumn(5)); + } + + std::vector sortedKeys; + sortedKeys.resize(keys.size()); + for (size_t i = 0; i < keys.size(); ++i) + sortedKeys.at(sequenceNumbers[i] - 1) = keys[i]; + + return sortedKeys; + } + +} // namespace + +void ReadAllTables(std::string_view database, std::string_view schema, EventHandler& eventHandler) +{ + auto const tableNames = AllTables(database, schema); + + for (auto& tableName: tableNames) + { + if (tableName == "sqlite_sequence") + continue; + + if (!eventHandler.OnTable(tableName)) + continue; + + auto const fullyQualifiedTableName = FullyQualifiedTableName { + .catalog = std::string(database), + .schema = std::string(schema), + .table = std::string(tableName), + }; + + auto const primaryKeys = AllPrimaryKeys(fullyQualifiedTableName); + eventHandler.OnPrimaryKeys(tableName, primaryKeys); + + auto const foreignKeys = AllForeignKeysFrom(fullyQualifiedTableName); + auto const incomingForeignKeys = AllForeignKeysTo(fullyQualifiedTableName); + + for (auto const& foreignKey: foreignKeys) + eventHandler.OnForeignKey(foreignKey); + + for (auto const& foreignKey: incomingForeignKeys) + eventHandler.OnExternalForeignKey(foreignKey); + + auto columnStmt = SqlStatement(); + auto const sqlResult = SQLColumns(columnStmt.NativeHandle(), + (SQLCHAR*) database.data(), + (SQLSMALLINT) database.size(), + (SQLCHAR*) schema.data(), + (SQLSMALLINT) schema.size(), + (SQLCHAR*) tableName.data(), + (SQLSMALLINT) tableName.size(), + nullptr /* column name */, + 0 /* column name length */); + if (!SQL_SUCCEEDED(sqlResult)) + throw std::runtime_error( + std::format("SQLColumns failed: {}", SqlErrorInfo::fromStatementHandle(columnStmt.NativeHandle()))); + + Column column; + + while (columnStmt.FetchRow()) + { + column.name = columnStmt.GetColumn(4); + column.type = FromNativeDataType(columnStmt.GetColumn(5)); + column.dialectDependantTypeString = columnStmt.GetColumn(6); + column.size = columnStmt.GetColumn(7); + // 8 - bufferLength + column.decimalDigits = columnStmt.GetColumn(9); + // 10 - NUM_PREC_RADIX + column.isNullable = columnStmt.GetColumn(11); + // 12 - remarks + column.defaultValue = columnStmt.GetColumn(13); + + // accumulated properties + column.isPrimaryKey = std::ranges::contains(primaryKeys, column.name); + // column.isForeignKey = ...; + column.isForeignKey = std::ranges::any_of( + foreignKeys, [&column](auto const& fk) { return fk.foreignKey.column == column.name; }); + if (auto const p = std::ranges::find_if( + incomingForeignKeys, [&column](auto const& fk) { return fk.foreignKey.column == column.name; }); + p != incomingForeignKeys.end()) + { + column.foreignKeyConstraint = *p; + } + + eventHandler.OnColumn(column); + } + + eventHandler.OnTableEnd(); + } +} + +std::string ToLowerCase(std::string_view str) +{ + std::string result(str); + std::transform( + result.begin(), result.end(), result.begin(), [](char c) { return static_cast(std::tolower(c)); }); + return result; +} + +TableList ReadAllTables(std::string_view database, std::string_view schema) +{ + TableList tables; + struct EventHandler: public SqlSchema::EventHandler + { + TableList& tables; + EventHandler(TableList& tables): + tables(tables) + { + } + + bool OnTable(std::string_view table) override + { + tables.emplace_back(Table { .name = std::string(table) }); + return true; + } + + void OnTableEnd() override {} + + void OnColumn(SqlSchema::Column const& column) override + { + tables.back().columns.emplace_back(column); + } + + void OnPrimaryKeys(std::string_view /*table*/, std::vector const& columns) override + { + tables.back().primaryKeys = columns; + } + + void OnForeignKey(SqlSchema::ForeignKeyConstraint const& foreignKeyConstraint) override + { + tables.back().foreignKeys.emplace_back(foreignKeyConstraint); + } + + void OnExternalForeignKey(SqlSchema::ForeignKeyConstraint const& foreignKeyConstraint) override + { + tables.back().externalForeignKeys.emplace_back(foreignKeyConstraint); + } + } eventHandler { tables }; + ReadAllTables(database, schema, eventHandler); + + std::map tableNameCaseMap; + for (auto const& table: tables) + tableNameCaseMap[ToLowerCase(table.name)] = table.name; + + // Fixup table names in foreign keys + // (Because at least Sqlite returns them in lowercase) + for (auto& table: tables) + { + for (auto& key: table.foreignKeys) + { + key.primaryKey.table.table = tableNameCaseMap.at(ToLowerCase(key.primaryKey.table.table)); + key.foreignKey.table.table = tableNameCaseMap.at(ToLowerCase(key.foreignKey.table.table)); + } + for (auto& key: table.externalForeignKeys) + { + key.primaryKey.table.table = tableNameCaseMap.at(ToLowerCase(key.primaryKey.table.table)); + key.foreignKey.table.table = tableNameCaseMap.at(ToLowerCase(key.foreignKey.table.table)); + } + } + + return tables; +} + +} // namespace SqlSchema diff --git a/src/Lightweight/SqlSchema.hpp b/src/Lightweight/SqlSchema.hpp new file mode 100644 index 00000000..3962ceb8 --- /dev/null +++ b/src/Lightweight/SqlSchema.hpp @@ -0,0 +1,181 @@ +#pragma once + +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "SqlTraits.hpp" + +#include +#include +#include +#include + +namespace SqlSchema +{ + +namespace detail +{ + constexpr std::string_view rtrim(std::string_view value) noexcept + { + while (!value.empty() && (std::isspace(value.back()) || value.back() == '\0')) + value.remove_suffix(1); + return value; + } +} // namespace detail + +struct FullyQualifiedTableName +{ + std::string catalog {}; + std::string schema {}; + std::string table {}; + + bool operator==(FullyQualifiedTableName const& other) const noexcept + { + return catalog == other.catalog && schema == other.schema && table == other.table; + } + + bool operator!=(FullyQualifiedTableName const& other) const noexcept + { + return !(*this == other); + } + + bool operator<(FullyQualifiedTableName const& other) const noexcept + { + return std::tie(catalog, schema, table) < std::tie(other.catalog, other.schema, other.table); + } +}; + +struct FullyQualifiedTableColumn +{ + FullyQualifiedTableName table; + std::string column; + + bool operator==(FullyQualifiedTableColumn const& other) const noexcept + { + return table == other.table && column == other.column; + } + + bool operator!=(FullyQualifiedTableColumn const& other) const noexcept + { + return !(*this == other); + } + + bool operator<(FullyQualifiedTableColumn const& other) const noexcept + { + return std::tie(table, column) < std::tie(other.table, other.column); + } +}; + +struct FullyQualifiedTableColumnSequence +{ + FullyQualifiedTableName table; + std::vector columns; +}; + +struct ForeignKeyConstraint +{ + FullyQualifiedTableColumn foreignKey; + FullyQualifiedTableColumnSequence primaryKey; +}; + +struct Column +{ + std::string name = {}; + SqlColumnType type = SqlColumnType::UNKNOWN; + std::string dialectDependantTypeString = {}; + bool isNullable = true; + bool isUnique = false; + size_t size = 0; + unsigned short decimalDigits; + bool isAutoIncrement = false; + bool isPrimaryKey = false; + bool isForeignKey = false; + std::optional foreignKeyConstraint {}; + std::string defaultValue = {}; +}; + +class EventHandler +{ + public: + virtual ~EventHandler() = default; + + virtual bool OnTable(std::string_view table) = 0; + virtual void OnPrimaryKeys(std::string_view table, std::vector const& columns) = 0; + virtual void OnForeignKey(ForeignKeyConstraint const& foreignKeyConstraint) = 0; + virtual void OnColumn(Column const& column) = 0; + virtual void OnExternalForeignKey(ForeignKeyConstraint const& foreignKeyConstraint) = 0; + virtual void OnTableEnd() = 0; +}; + +void ReadAllTables(std::string_view database, std::string_view schema, EventHandler& eventHandler); + +struct Table +{ + // FullyQualifiedTableName name; + std::string name; + std::vector columns {}; + std::vector foreignKeys {}; + std::vector externalForeignKeys {}; + std::vector primaryKeys {}; +}; + +using TableList = std::vector; + +TableList ReadAllTables(std::string_view database, std::string_view schema = {}); + +} // namespace SqlSchema + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlSchema::FullyQualifiedTableName const& value, format_context& ctx) const -> format_context::iterator + { + string output = std::string(SqlSchema::detail::rtrim(value.schema)); + if (!output.empty()) + output += '.'; + auto const trimmedSchema = SqlSchema::detail::rtrim(value.schema); + output += trimmedSchema; + if (!output.empty() && !trimmedSchema.empty()) + output += '.'; + output += SqlSchema::detail::rtrim(value.table); + return formatter::format(output, ctx); + } +}; + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlSchema::FullyQualifiedTableColumn const& value, + format_context& ctx) const -> format_context::iterator + { + auto const table = std::format("{}", value.table); + if (table.empty()) + return formatter::format(std::format("{}", value.column), ctx); + else + return formatter::format(std::format("{}.{}", value.table, value.column), ctx); + } +}; + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlSchema::FullyQualifiedTableColumnSequence const& value, + format_context& ctx) const -> format_context::iterator + { + auto const resolvedTableName = std::format("{}", value.table); + string output; + + for (auto const& column: value.columns) + { + if (!output.empty()) + output += ", "; + output += resolvedTableName; + if (!output.empty() && !resolvedTableName.empty()) + output += '.'; + output += column; + } + + return formatter::format(output, ctx); + } +}; diff --git a/src/Lightweight/SqlScopedTraceLogger.hpp b/src/Lightweight/SqlScopedTraceLogger.hpp new file mode 100644 index 00000000..4930671c --- /dev/null +++ b/src/Lightweight/SqlScopedTraceLogger.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include "SqlConnection.hpp" +#include "SqlStatement.hpp" + +#include + +// TODO: move to public API +class SqlScopedTraceLogger +{ + SQLHDBC m_nativeConnection; + + public: + explicit SqlScopedTraceLogger(SqlStatement& stmt): + SqlScopedTraceLogger(stmt.Connection().NativeHandle(), +#if defined(_WIN32) || defined(_WIN64) + "CONOUT$" +#else + "/dev/stdout" +#endif + ) + { + } + + explicit SqlScopedTraceLogger(SQLHDBC hDbc, std::filesystem::path const& logFile): + m_nativeConnection { hDbc } + { + SQLSetConnectAttrA(m_nativeConnection, SQL_ATTR_TRACEFILE, (SQLPOINTER) logFile.string().c_str(), SQL_NTS); + SQLSetConnectAttrA(m_nativeConnection, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_ON, SQL_IS_UINTEGER); + } + + ~SqlScopedTraceLogger() + { + SQLSetConnectAttrA(m_nativeConnection, SQL_ATTR_TRACE, (SQLPOINTER) SQL_OPT_TRACE_OFF, SQL_IS_UINTEGER); + } +}; diff --git a/src/Lightweight/SqlStatement.cpp b/src/Lightweight/SqlStatement.cpp new file mode 100644 index 00000000..c23199ce --- /dev/null +++ b/src/Lightweight/SqlStatement.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +#include "SqlStatement.hpp" + +SqlStatement::SqlStatement() noexcept: + m_ownedConnection { SqlConnection() }, + m_connection { &m_ownedConnection.value() }, + m_lastError { m_connection->LastError() } +{ + if (m_lastError == SqlError::SUCCESS) + RequireSuccess(SQLAllocHandle(SQL_HANDLE_STMT, m_connection->NativeHandle(), &m_hStmt)); +} + +// Construct a new SqlStatement object, using the given connection. +SqlStatement::SqlStatement(SqlConnection& relatedConnection): + m_connection { &relatedConnection } +{ + RequireSuccess(SQLAllocHandle(SQL_HANDLE_STMT, m_connection->NativeHandle(), &m_hStmt)); +} + +SqlStatement::~SqlStatement() noexcept +{ + SQLFreeHandle(SQL_HANDLE_STMT, m_hStmt); +} + +void SqlStatement::Prepare(std::string_view query) +{ + SqlLogger::GetLogger().OnPrepare(query); + + m_postExecuteCallbacks.clear(); + m_postProcessOutputColumnCallbacks.clear(); + + // Closes the cursor if it is open + RequireSuccess(SQLFreeStmt(m_hStmt, SQL_CLOSE)); + + // Prepares the statement + RequireSuccess(SQLPrepareA(m_hStmt, (SQLCHAR*) query.data(), (SQLINTEGER) query.size())); + RequireSuccess(SQLNumParams(m_hStmt, &m_expectedParameterCount)); + m_indicators.resize(m_expectedParameterCount + 1); +} + +void SqlStatement::ExecuteDirect(const std::string_view& query, std::source_location location) +{ + if (query.empty()) + return; + + SqlLogger::GetLogger().OnExecuteDirect(query); + + RequireSuccess(SQLFreeStmt(m_hStmt, SQL_CLOSE), location); + RequireSuccess(SQLExecDirectA(m_hStmt, (SQLCHAR*) query.data(), (SQLINTEGER) query.size()), location); +} + +// Retrieves the number of rows affected by the last query. +size_t SqlStatement::NumRowsAffected() const +{ + SQLLEN numRowsAffected {}; + RequireSuccess(SQLRowCount(m_hStmt, &numRowsAffected)); + return numRowsAffected; +} + +// Retrieves the number of columns affected by the last query. +size_t SqlStatement::NumColumnsAffected() const +{ + SQLSMALLINT numColumns {}; + RequireSuccess(SQLNumResultCols(m_hStmt, &numColumns)); + return numColumns; +} + +// Retrieves the last insert ID of the last query's primary key. +size_t SqlStatement::LastInsertId() +{ + return ExecuteDirectScalar(m_connection->Traits().LastInsertIdQuery).value_or(0); +} + +// Fetches the next row of the result set. +bool SqlStatement::FetchRow() +{ + auto const sqlResult = SQLFetch(m_hStmt); + switch (sqlResult) + { + case SQL_NO_DATA: + return false; + default: + RequireSuccess(sqlResult); + // post-process the output columns, if needed + for (auto const& postProcess: m_postProcessOutputColumnCallbacks) + postProcess(); + m_postProcessOutputColumnCallbacks.clear(); + return true; + } +} + +void SqlStatement::RequireSuccess(SQLRETURN error, std::source_location sourceLocation) const +{ + auto result = detail::UpdateSqlError(&m_lastError, error); + if (result.has_value()) + return; + + auto errorInfo = SqlErrorInfo::fromStatementHandle(m_hStmt); + SqlLogger::GetLogger().OnError(m_lastError, errorInfo, sourceLocation); + if (errorInfo.sqlState == "07009") + throw std::invalid_argument(std::format("SQL error: {}", errorInfo)); + else + throw std::runtime_error(std::format("SQL error: {}", errorInfo)); +} diff --git a/src/Lightweight/SqlStatement.hpp b/src/Lightweight/SqlStatement.hpp new file mode 100644 index 00000000..64f2241f --- /dev/null +++ b/src/Lightweight/SqlStatement.hpp @@ -0,0 +1,347 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "SqlConnection.hpp" +#include "SqlDataBinder.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// High level API for (prepared) raw SQL statements +// +// SQL prepared statement lifecycle: +// 1. Prepare the statement +// 2. Optionally bind output columns to local variables +// 3. Execute the statement (optionally with input parameters) +// 4. Fetch rows (if any) +// 5. Repeat steps 3 and 4 as needed +class SqlStatement final: public SqlDataBinderCallback +{ + public: + // Construct a new SqlStatement object, using a new connection, and connect to the default database. + SqlStatement() noexcept; + + SqlStatement(SqlStatement&&) noexcept = default; + + // Construct a new SqlStatement object, using the given connection. + SqlStatement(SqlConnection& relatedConnection); + + ~SqlStatement() noexcept final; + + // Retrieves the connection associated with this statement. + SqlConnection& Connection() noexcept; + + // Retrieves the connection associated with this statement. + SqlConnection const& Connection() const noexcept; + + // Retrieves the native handle of the statement. + [[nodiscard]] SQLHSTMT NativeHandle() const noexcept; + + // Retrieves the last error code. + [[nodiscard]] SqlError LastError() const noexcept; + + // Prepares the statement for execution. + void Prepare(std::string_view query); + + template + void BindInputParameter(SQLSMALLINT columnIndex, Arg const& arg); + + // Binds the given arguments to the prepared statement to store the fetched data to. + // + // The statement must be prepared before calling this function. + template + void BindOutputColumns(Args*... args); + + template + void BindOutputColumn(SQLUSMALLINT columnIndex, T* arg); + + // Binds the given arguments to the prepared statement and executes it. + template + void Execute(Args const&... args); + + // Executes the prepared statement on a a batch of data. + // + // Each parameter represents a column, to be bound as input parameter. + // The element types of each column container must be explicitly supported. + // + // In order to support column value types, their underlying storage must be contiguous. + // Also the input range itself must be contiguous. + // If any of these conditions are not met, the function will not compile - use ExecuteBatch() instead. + template + void ExecuteBatchNative(FirstColumnBatch const& firstColumnBatch, MoreColumnBatches const&... moreColumnBatches); + + // Executes the prepared statement on a a batch of data. + // + // Each parameter represents a column, to be bound as input parameter, + // and the number of elements in these bound column containers will + // mandate how many executions will happen. + template + void ExecuteBatch(FirstColumnBatch const& firstColumnBatch, MoreColumnBatches const&... moreColumnBatches); + + // Executes the given query directly. + void ExecuteDirect(const std::string_view& query, std::source_location location = std::source_location::current()); + + // Executes the given query, assuming that only one result row and column is affected, that one will be + // returned. + template + [[nodiscard]] std::optional ExecuteDirectScalar(const std::string_view& query, + std::source_location location = std::source_location::current()); + + // Retrieves the number of rows affected by the last query. + [[nodiscard]] size_t NumRowsAffected() const; + + // Retrieves the number of columns affected by the last query. + [[nodiscard]] size_t NumColumnsAffected() const; + + // Retrieves the last insert ID of the last query's primary key. + [[nodiscard]] size_t LastInsertId(); + + // Fetches the next row of the result set. + [[nodiscard]] bool FetchRow(); + + // Retrieves the value of the column at the given index for the currently selected row. + template + void GetColumn(SQLUSMALLINT column, T* result) const; + + // Retrieves the value of the column at the given index for the currently selected row. + template + [[nodiscard]] T GetColumn(SQLUSMALLINT column) const; + + private: + void RequireSuccess(SQLRETURN error, std::source_location sourceLocation = std::source_location::current()) const; + void PlanPostExecuteCallback(std::function&& cb) override; + void PlanPostProcessOutputColumn(std::function&& cb) override; + void ProcessPostExecuteCallbacks(); + + // private data members + + std::optional m_ownedConnection; // The connection object (if owned) + SqlConnection* m_connection {}; // Pointer to the connection object + SQLHSTMT m_hStmt {}; // The native oDBC statement handle + mutable SqlError m_lastError {}; // The last error code + SQLSMALLINT m_expectedParameterCount {}; // The number of parameters expected by the query + std::vector m_indicators; // Holds the indicators for the bound output columns + std::vector> m_postExecuteCallbacks; + std::vector> m_postProcessOutputColumnCallbacks; +}; + +// {{{ inline implementation +inline SqlConnection& SqlStatement::Connection() noexcept +{ + return *m_connection; +} + +inline SqlConnection const& SqlStatement::Connection() const noexcept +{ + return *m_connection; +} + +[[nodiscard]] inline SQLHSTMT SqlStatement::NativeHandle() const noexcept +{ + return m_hStmt; +} + +[[nodiscard]] inline SqlError SqlStatement::LastError() const noexcept +{ + return m_lastError; +} + +template +void SqlStatement::BindOutputColumns(Args*... args) +{ + auto const numColumns = NumColumnsAffected(); + m_indicators.resize(numColumns + 1); + + SQLUSMALLINT i = 0; + ((++i, SqlDataBinder::OutputColumn(m_hStmt, i, args, &m_indicators[i], *this)), ...); +} + +template +void SqlStatement::BindOutputColumn(SQLUSMALLINT columnIndex, T* arg) +{ + if (m_indicators.size() <= columnIndex) + m_indicators.resize(NumColumnsAffected() + 1); + + SqlDataBinder::OutputColumn(m_hStmt, columnIndex, arg, &m_indicators[columnIndex], *this); +} + +template +void SqlStatement::BindInputParameter(SQLSMALLINT columnIndex, Arg const& arg) +{ + // tell Execute() that we don't know the expected count + m_expectedParameterCount = (std::numeric_limits::max)(); + RequireSuccess(SqlDataBinder::InputParameter(m_hStmt, columnIndex, arg)); +} + +template +void SqlStatement::Execute(Args const&... args) +{ + // Each input parameter must have an address, + // such that we can call SQLBindParameter() without needing to copy it. + // The memory region behind the input parameter must exist until the SQLExecute() call. + + SqlLogger::GetLogger().OnExecute(); + + if (!(m_expectedParameterCount == (std::numeric_limits::max)() + && sizeof...(args) == 0) + && !(m_expectedParameterCount == sizeof...(args))) + throw std::invalid_argument { "Invalid argument count" }; + + SQLUSMALLINT i = 0; + ((++i, SqlDataBinder::InputParameter(m_hStmt, i, args)), ...); + + RequireSuccess(SQLExecute(m_hStmt)); + ProcessPostExecuteCallbacks(); +} + +// clang-format off +template +concept SqlNativeContiguousValueConcept = + std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as + || std::same_as>; + +template +concept SqlNativeBatchable = + std::ranges::contiguous_range + && (std::ranges::contiguous_range && ...) + && SqlNativeContiguousValueConcept> + && (SqlNativeContiguousValueConcept> && ...); + +// clang-format on + +template +void SqlStatement::ExecuteBatchNative(FirstColumnBatch const& firstColumnBatch, + MoreColumnBatches const&... moreColumnBatches) +{ + static_assert(SqlNativeBatchable, + "Must be a supported native contiguous element type."); + + if (m_expectedParameterCount != 1 + sizeof...(moreColumnBatches)) + throw std::invalid_argument { "Invalid number of columns" }; + + const auto rowCount = std::ranges::size(firstColumnBatch); + if (!((std::size(moreColumnBatches) == rowCount) && ...)) + throw std::invalid_argument { "Uneven number of rows" }; + + size_t rowStart = 0; + + // clang-format off + RequireSuccess(SQLSetStmtAttr(m_hStmt, SQL_ATTR_PARAMSET_SIZE, (SQLPOINTER) rowCount, 0)); + RequireSuccess(SQLSetStmtAttr(m_hStmt, SQL_ATTR_PARAM_BIND_OFFSET_PTR, &rowStart, 0)); + RequireSuccess(SQLSetStmtAttr(m_hStmt, SQL_ATTR_PARAM_BIND_TYPE, SQL_PARAM_BIND_BY_COLUMN, 0)); + RequireSuccess(SQLSetStmtAttr(m_hStmt, SQL_ATTR_PARAM_OPERATION_PTR, SQL_PARAM_PROCEED, 0)); + RequireSuccess(SqlDataBinder>:: + InputParameter(m_hStmt, 1, *std::ranges::data(firstColumnBatch))); + SQLUSMALLINT column = 1; + (RequireSuccess(SqlDataBinder>:: + InputParameter(m_hStmt, ++column, *std::ranges::data(moreColumnBatches))), ...); + RequireSuccess(SQLExecute(m_hStmt)); + ProcessPostExecuteCallbacks(); + // clang-format on +} + +template +void SqlStatement::ExecuteBatch(FirstColumnBatch const& firstColumnBatch, MoreColumnBatches const&... moreColumnBatches) +{ + // If the input ranges are contiguous and their element types are contiguous and supported as well, + // we can use the native batch execution. + if constexpr (SqlNativeBatchable) + { + ExecuteBatchNative(firstColumnBatch, moreColumnBatches...); + return; + } + + if (m_expectedParameterCount != 1 + sizeof...(moreColumnBatches)) + throw std::invalid_argument { "Invalid number of columns" }; + + const auto rowCount = std::ranges::size(firstColumnBatch); + if (!((std::size(moreColumnBatches) == rowCount) && ...)) + throw std::invalid_argument { "Uneven number of rows" }; + + m_lastError = SqlError::SUCCESS; + for (auto const rowIndex: std::views::iota(size_t { 0 }, rowCount)) + { + std::apply( + [&](ColumnValues const&... columnsInRow) { + SQLUSMALLINT column = 0; + ((++column, SqlDataBinder::InputParameter(m_hStmt, column, columnsInRow)), ...); + RequireSuccess(SQLExecute(m_hStmt)); + ProcessPostExecuteCallbacks(); + }, + std::make_tuple(std::ref(*std::ranges::next(std::ranges::begin(firstColumnBatch), rowIndex)), + std::ref(*std::ranges::next(std::ranges::begin(moreColumnBatches), rowIndex))...)); + } +} + +template +inline void SqlStatement::GetColumn(SQLUSMALLINT column, T* result) const +{ + SQLLEN indicator {}; // TODO: Handle NULL values if we find out that we need them for our use-cases. + RequireSuccess(SqlDataBinder::GetColumn(m_hStmt, column, result, &indicator)); +} + +template +[[nodiscard]] T SqlStatement::GetColumn(SQLUSMALLINT column) const +{ + T result {}; + SQLLEN indicator {}; // TODO: Handle NULL values if we find out that we need them for our use-cases. + RequireSuccess(SqlDataBinder::GetColumn(m_hStmt, column, &result, &indicator)); + return result; +} + +inline void SqlStatement::PlanPostExecuteCallback(std::function&& cb) +{ + m_postExecuteCallbacks.emplace_back(std::move(cb)); +} + +inline void SqlStatement::ProcessPostExecuteCallbacks() +{ + for (auto& cb: m_postExecuteCallbacks) + cb(); + m_postExecuteCallbacks.clear(); +} + +inline void SqlStatement::PlanPostProcessOutputColumn(std::function&& cb) +{ + m_postProcessOutputColumnCallbacks.emplace_back(std::move(cb)); +} + +template +std::optional SqlStatement::ExecuteDirectScalar(const std::string_view& query, std::source_location location) +{ + ExecuteDirect(query, location); + if (FetchRow()) + return { GetColumn(1) }; + return std::nullopt; +} + +// }}} diff --git a/src/Lightweight/SqlTraits.hpp b/src/Lightweight/SqlTraits.hpp new file mode 100644 index 00000000..3f233077 --- /dev/null +++ b/src/Lightweight/SqlTraits.hpp @@ -0,0 +1,188 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Represents the type of SQL server, used to determine the correct SQL syntax, if needed. +enum class SqlServerType : uint8_t +{ + UNKNOWN, + MICROSOFT_SQL, + POSTGRESQL, + ORACLE, + SQLITE, + MYSQL, +}; + +enum class SqlColumnType : uint8_t +{ + UNKNOWN, + CHAR, + STRING, + TEXT, + BOOLEAN, + INTEGER, + REAL, + BLOB, + DATE, + TIME, + DATETIME, +}; + +namespace detail +{ + +constexpr std::string_view DefaultColumnTypeName(SqlColumnType value) noexcept +{ + switch (value) + { + case SqlColumnType::CHAR: + return "CHAR"; + case SqlColumnType::STRING: + return "VARCHAR(255)"; // FIXME: This is a guess. Define and use column width somewhere + case SqlColumnType::TEXT: + return "TEXT"; + case SqlColumnType::BOOLEAN: + return "BOOL"; + case SqlColumnType::INTEGER: + return "INTEGER"; + case SqlColumnType::REAL: + return "REAL"; + case SqlColumnType::BLOB: + return "BLOB"; + case SqlColumnType::DATE: + return "DATE"; + case SqlColumnType::TIME: + return "TIME"; + case SqlColumnType::DATETIME: + return "DATETIME"; + case SqlColumnType::UNKNOWN: + break; + } + return "UNKNOWN"; +} + +constexpr std::string_view MSSqlColumnTypeName(SqlColumnType value) noexcept +{ + switch (value) + { + case SqlColumnType::TEXT: + return "VARCHAR(MAX)"; + case SqlColumnType::BOOLEAN: + return "BIT"; + default: + return DefaultColumnTypeName(value); + } +} + +} // namespace detail + +struct SqlTraits +{ + std::string_view LastInsertIdQuery; + std::string_view PrimaryKeyAutoIncrement; + std::string_view CurrentTimestampExpr; + std::string_view EnforceForeignKeyConstraint {}; + size_t MaxStatementLength {}; + std::function ColumnTypeName {}; +}; + +namespace detail +{ + +inline SqlTraits const MicrosoftSqlTraits { + .LastInsertIdQuery = "SELECT @@IDENTITY;", + .PrimaryKeyAutoIncrement = "INT IDENTITY(1,1) PRIMARY KEY", + .CurrentTimestampExpr = "GETDATE()", + .ColumnTypeName = detail::MSSqlColumnTypeName, +}; + +inline SqlTraits const PostgresSqlTraits { + .LastInsertIdQuery = "SELECT LASTVAL()", + .PrimaryKeyAutoIncrement = "SERIAL PRIMARY KEY", + .CurrentTimestampExpr = "CURRENT_TIMESTAMP", + .ColumnTypeName = [](SqlColumnType value) -> std::string_view { + switch (value) + { + case SqlColumnType::DATETIME: + return "TIMESTAMP"; + default: + return detail::DefaultColumnTypeName(value); + } + }, +}; + +inline SqlTraits const OracleSqlTraits { + .LastInsertIdQuery = "SELECT LAST_INSERT_ID() FROM DUAL", + .PrimaryKeyAutoIncrement = "NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY PRIMARY KEY", + .CurrentTimestampExpr = "SYSTIMESTAMP", + .ColumnTypeName = detail::DefaultColumnTypeName, +}; + +inline SqlTraits const SQLiteTraits { + .LastInsertIdQuery = "SELECT LAST_INSERT_ROWID()", + .PrimaryKeyAutoIncrement = "INTEGER PRIMARY KEY AUTOINCREMENT", + .CurrentTimestampExpr = "CURRENT_TIMESTAMP", + .EnforceForeignKeyConstraint = "PRAGMA foreign_keys = ON", + .ColumnTypeName = detail::DefaultColumnTypeName, +}; + +inline SqlTraits const MySQLTraits { + .LastInsertIdQuery = "SELECT LAST_INSERT_ID()", + .PrimaryKeyAutoIncrement = "INT AUTO_INCREMENT PRIMARY KEY", + .CurrentTimestampExpr = "NOW()", + .ColumnTypeName = detail::DefaultColumnTypeName, +}; + +inline SqlTraits const UnknownSqlTraits { + .LastInsertIdQuery = "", + .PrimaryKeyAutoIncrement = "", + .CurrentTimestampExpr = "", + .ColumnTypeName = detail::DefaultColumnTypeName, +}; + +} // namespace detail + +inline SqlTraits const& GetSqlTraits(SqlServerType serverType) noexcept +{ + auto static const sqlTraits = std::array { + &detail::UnknownSqlTraits, &detail::MicrosoftSqlTraits, &detail::PostgresSqlTraits, + &detail::OracleSqlTraits, &detail::SQLiteTraits, + }; + + return *sqlTraits[static_cast(serverType)]; +} + +template <> +struct std::formatter: std::formatter +{ + auto format(SqlServerType type, format_context& ctx) const -> format_context::iterator + { + string_view name; + switch (type) + { + case SqlServerType::MICROSOFT_SQL: + name = "Microsoft SQL Server"; + break; + case SqlServerType::POSTGRESQL: + name = "PostgreSQL"; + break; + case SqlServerType::ORACLE: + name = "Oracle"; + break; + case SqlServerType::SQLITE: + name = "SQLite"; + break; + case SqlServerType::MYSQL: + name = "MySQL"; + break; + case SqlServerType::UNKNOWN: + name = "Unknown"; + break; + } + return std::formatter::format(name, ctx); + } +}; diff --git a/src/Lightweight/SqlTransaction.cpp b/src/Lightweight/SqlTransaction.cpp new file mode 100644 index 00000000..40f75ad4 --- /dev/null +++ b/src/Lightweight/SqlTransaction.cpp @@ -0,0 +1,53 @@ +#include "SqlConnection.hpp" +#include "SqlTransaction.hpp" + +SqlTransaction::SqlTransaction(SqlConnection& connection, SqlTransactionMode defaultMode) noexcept: + m_hDbc { connection.NativeHandle() }, + m_defaultMode { defaultMode } +{ + SQLSetConnectAttr(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_OFF, SQL_IS_UINTEGER); +} + +SqlTransaction::~SqlTransaction() +{ + switch (m_defaultMode) + { + case SqlTransactionMode::NONE: + break; + case SqlTransactionMode::COMMIT: + Commit(); + break; + case SqlTransactionMode::ROLLBACK: + Rollback(); + break; + } +} + +std::expected SqlTransaction::Rollback() +{ + SQLRETURN sqlReturn = SQLEndTran(SQL_HANDLE_DBC, m_hDbc, SQL_ROLLBACK); + if (sqlReturn != SQL_SUCCESS && sqlReturn != SQL_SUCCESS_WITH_INFO) + return std::unexpected { SqlErrorInfo::fromConnectionHandle(m_hDbc) }; + ; + sqlReturn = SQLSetConnectAttr(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER); + if (sqlReturn != SQL_SUCCESS && sqlReturn != SQL_SUCCESS_WITH_INFO) + return std::unexpected { SqlErrorInfo::fromConnectionHandle(m_hDbc) }; + + m_defaultMode = SqlTransactionMode::NONE; + return {}; +} + +// Commit the transaction +std::expected SqlTransaction::Commit() +{ + SQLRETURN sqlReturn = SQLEndTran(SQL_HANDLE_DBC, m_hDbc, SQL_COMMIT); + if (sqlReturn != SQL_SUCCESS && sqlReturn != SQL_SUCCESS_WITH_INFO) + return std::unexpected { SqlErrorInfo::fromConnectionHandle(m_hDbc) }; + + sqlReturn = SQLSetConnectAttr(m_hDbc, SQL_ATTR_AUTOCOMMIT, (SQLPOINTER) SQL_AUTOCOMMIT_ON, SQL_IS_UINTEGER); + if (sqlReturn != SQL_SUCCESS && sqlReturn != SQL_SUCCESS_WITH_INFO) + return std::unexpected { SqlErrorInfo::fromConnectionHandle(m_hDbc) }; + + m_defaultMode = SqlTransactionMode::NONE; + return {}; +} diff --git a/src/Lightweight/SqlTransaction.hpp b/src/Lightweight/SqlTransaction.hpp new file mode 100644 index 00000000..7d3461e9 --- /dev/null +++ b/src/Lightweight/SqlTransaction.hpp @@ -0,0 +1,52 @@ +#pragma once + +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "SqlConcepts.hpp" +#include "SqlError.hpp" + +#include +#include +#include +#include + +class SqlConnection; + +// Represents the mode of a SQL transaction to be applied, if not done so explicitly. +enum class SqlTransactionMode +{ + NONE, + COMMIT, + ROLLBACK, +}; + +// Represents a transaction to a SQL database. +// +// This class is used to control the transaction manually. It disables the auto-commit mode when constructed, +// and automatically commits the transaction when destructed if not done so. +// +// This class is designed with RAII in mind, so that the transaction is automatically committed or rolled back +// when the object goes out of scope. +class SqlTransaction +{ + public: + // Construct a new SqlTransaction object, and disable the auto-commit mode, so that the transaction can be + // controlled manually. + explicit SqlTransaction(SqlConnection& connection, + SqlTransactionMode defaultMode = SqlTransactionMode::COMMIT) noexcept; + + // Automatically commit the transaction if not done so + ~SqlTransaction(); + + // Rollback the transaction + std::expected Rollback(); + + // Commit the transaction + std::expected Commit(); + + private: + SQLHDBC m_hDbc; + SqlTransactionMode m_defaultMode; +}; diff --git a/src/Lightweight/SqlUtils.hpp b/src/Lightweight/SqlUtils.hpp new file mode 100644 index 00000000..03307138 --- /dev/null +++ b/src/Lightweight/SqlUtils.hpp @@ -0,0 +1,22 @@ +#pragma once +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "SqlConnection.hpp" +#include "SqlError.hpp" +#include "SqlStatement.hpp" + +#include +#include +#include +#include + +namespace SqlUtils +{ + +SqlResult> TableNames(std::string_view database, std::string_view schema = {}); + +SqlResult> ColumnNames(std::string_view tableName, std::string_view schema = {}); + +} // namespace SqlUtils diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt new file mode 100644 index 00000000..082965cc --- /dev/null +++ b/src/tests/CMakeLists.txt @@ -0,0 +1,22 @@ +find_package(Catch2 REQUIRED) + +add_executable(LightweightTest) +target_compile_features(LightweightTest PUBLIC cxx_std_23) + +set(TEST_LIBRARIES Catch2::Catch2 Lightweight::Lightweight) +if(MSVC) + target_compile_options(LightweightTest PRIVATE /MP) +else() + set(TEST_LIBRARIES ${TEST_LIBRARIES} odbc ${SQLITE3_LIBRARY}) # FIXME: should be PkgConfig::ODBC in Lightweight target already +endif() + +target_link_libraries(LightweightTest PRIVATE ${TEST_LIBRARIES}) + +target_sources(LightweightTest PRIVATE + LightweightTests.cpp + ModelTests.cpp + ModelAssociationsTests.cpp +) + +enable_testing() +add_test(NAME LightweightTest COMMAND LightweightTest) diff --git a/src/tests/LightweightTests.cpp b/src/tests/LightweightTests.cpp new file mode 100644 index 00000000..e2d0617a --- /dev/null +++ b/src/tests/LightweightTests.cpp @@ -0,0 +1,837 @@ +#include "../Lightweight/SqlScopedTraceLogger.hpp" +#include "Utils.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) + // Disable the warning C4834: discarding return value of function with 'nodiscard' attribute. + // Because we are simply testing and demonstrating the library and not using it in production code. + #pragma warning(disable : 4834) +#endif + +using namespace std::string_view_literals; + +int main(int argc, char** argv) +{ + auto result = SqlTestFixture::Initialize(argc, argv); + if (!result.has_value()) + return result.error(); + + std::tie(argc, argv) = result.value(); + + struct finally + { + ~finally() + { + SqlLogger::GetLogger().OnStats(SqlConnection::Stats()); + } + } _; + + return Catch::Session().run(argc, argv); +} + +namespace +{ + +void CreateEmployeesTable(SqlStatement& stmt, std::source_location sourceLocation = std::source_location::current()) +{ + stmt.ExecuteDirect(std::format(R"SQL(CREATE TABLE Employees ( + EmployeeID {}, + FirstName VARCHAR(50) NOT NULL, + LastName VARCHAR(50), + Salary INT NOT NULL + ); + )SQL", + stmt.Connection().Traits().PrimaryKeyAutoIncrement), + sourceLocation); +} + +void FillEmployeesTable(SqlStatement& stmt) +{ + stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Execute("Alice", "Smith", 50'000); + stmt.Execute("Bob", "Johnson", 60'000); + stmt.Execute("Charlie", "Brown", 70'000); +} + +} // namespace + +TEST_CASE_METHOD(SqlTestFixture, "SqlFixedString: resize and clear") +{ + SqlFixedString<8> str; + + REQUIRE(str.size() == 0); + REQUIRE(str.empty()); + + str.resize(1, 'x'); + REQUIRE(!str.empty()); + REQUIRE(str.size() == 1); + REQUIRE(str == "x"); + + str.resize(4, 'y'); + REQUIRE(str.size() == 4); + REQUIRE(str == "xyyy"); + + // one-off overflow truncates + str.resize(9, 'z'); + REQUIRE(str.size() == 8); + REQUIRE(str == "xyyyzzzz"); + + // resize down + str.resize(2); + REQUIRE(str.size() == 2); + REQUIRE(str == "xy"); + + str.clear(); + REQUIRE(str.size() == 0); + REQUIRE(str == ""); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlFixedString: push_back and pop_back") +{ + SqlFixedString<2> str; + + str.push_back('a'); + str.push_back('b'); + REQUIRE(str == "ab"); + + // overflow: no-op (truncates) + str.push_back('c'); + REQUIRE(str == "ab"); + + str.pop_back(); + REQUIRE(str == "a"); + + str.pop_back(); + REQUIRE(str == ""); + + // no-op + str.pop_back(); + REQUIRE(str == ""); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlFixedString: assign") +{ + SqlFixedString<12> str; + str.assign("Hello, World"); + REQUIRE(str == "Hello, World"); + // str.assign("Hello, World!"); <-- would fail due to static_assert + str.assign("Hello, World!"sv); + REQUIRE(str == "Hello, World"); + + str = "Something"; + REQUIRE(str == "Something"); + // str = ("Hello, World!"); // <-- would fail due to static_assert +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlFixedString: c_str") +{ + SqlFixedString<12> str { "Hello, World" }; + str.resize(5); + REQUIRE(str.data()[5] == ','); + + SqlFixedString<12> const& constStr = str; + REQUIRE(constStr.c_str() == "Hello"sv); // Call to `c_str() const` also mutates [5] to NUL + REQUIRE(str.data()[5] == '\0'); + + str.resize(2); + REQUIRE(str.data()[2] == 'l'); + REQUIRE(str.c_str() == "He"sv); // Call to `c_str()` also mutates [2] to NUL + REQUIRE(str.data()[2] == '\0'); +} + +TEST_CASE_METHOD(SqlTestFixture, "select: get columns") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("SELECT 42"); + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == 42); + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "select: get column (invalid index)") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("SELECT 42"); + REQUIRE(stmt.FetchRow()); + + auto const _ = ScopedSqlNullLogger {}; // suppress the error message, we are testing for it + + CHECK_THROWS_AS(stmt.GetColumn(2), std::invalid_argument); + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "execute bound parameters and select back: VARCHAR, INT") +{ + auto stmt = SqlStatement {}; + CreateEmployeesTable(stmt); + + stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Execute("Alice", "Smith", 50'000); + stmt.Execute("Bob", "Johnson", 60'000); + stmt.Execute("Charlie", "Brown", 70'000); + + stmt.ExecuteDirect("SELECT COUNT(*) FROM Employees"); + REQUIRE(stmt.NumColumnsAffected() == 1); + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == 3); + REQUIRE(!stmt.FetchRow()); + + stmt.Prepare("SELECT FirstName, LastName, Salary FROM Employees WHERE Salary >= ?"); + REQUIRE(stmt.NumColumnsAffected() == 3); + stmt.Execute(55'000); + + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == "Bob"); + REQUIRE(stmt.GetColumn(2) == "Johnson"); + REQUIRE(stmt.GetColumn(3) == 60'000); + + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == "Charlie"); + REQUIRE(stmt.GetColumn(2) == "Brown"); + REQUIRE(stmt.GetColumn(3) == 70'000); + + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "transaction: auto-rollback") +{ + auto stmt = SqlStatement {}; + REQUIRE(stmt.Connection().TransactionsAllowed()); + CreateEmployeesTable(stmt); + + { + auto transaction = SqlTransaction { stmt.Connection(), SqlTransactionMode::ROLLBACK }; + stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Execute("Alice", "Smith", 50'000); + REQUIRE(stmt.Connection().TransactionActive()); + } + // transaction automatically rolled back + + REQUIRE(!stmt.Connection().TransactionActive()); + stmt.ExecuteDirect("SELECT COUNT(*) FROM Employees"); + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == 0); +} + +TEST_CASE_METHOD(SqlTestFixture, "transaction: auto-commit") +{ + auto stmt = SqlStatement {}; + REQUIRE(stmt.Connection().TransactionsAllowed()); + CreateEmployeesTable(stmt); + + { + auto transaction = SqlTransaction { stmt.Connection(), SqlTransactionMode::COMMIT }; + stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Execute("Alice", "Smith", 50'000); + REQUIRE(stmt.Connection().TransactionActive()); + } + // transaction automatically committed + + REQUIRE(!stmt.Connection().TransactionActive()); + stmt.ExecuteDirect("SELECT COUNT(*) FROM Employees"); + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == 1); +} + +TEST_CASE_METHOD(SqlTestFixture, "execute binding output parameters (direct)") +{ + auto stmt = SqlStatement {}; + CreateEmployeesTable(stmt); + FillEmployeesTable(stmt); + + std::string firstName(20, '\0'); // pre-allocation for output parameter strings is important + std::string lastName(20, '\0'); // ditto + unsigned int salary {}; + + stmt.Prepare("SELECT FirstName, LastName, Salary FROM Employees WHERE Salary = ?"); + stmt.BindOutputColumns(&firstName, &lastName, &salary); + stmt.Execute(50'000); + + REQUIRE(stmt.FetchRow()); + CHECK(firstName == "Alice"); + CHECK(lastName == "Smith"); + CHECK(salary == 50'000); + + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "FetchRow can auto-trim string if requested") +{ + auto stmt = SqlStatement {}; + CreateEmployeesTable(stmt); + stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Execute("Alice ", "Smith ", 50'000); + + SqlTrimmedString firstName { .value = std::string(20, '\0') }; + SqlTrimmedString lastName { .value = std::string(20, '\0') }; + + stmt.ExecuteDirect("SELECT FirstName, LastName FROM Employees"); + stmt.BindOutputColumns(&firstName, &lastName); + + REQUIRE(stmt.FetchRow()); + CHECK(firstName.value == "Alice"); + CHECK(lastName.value == "Smith"); + + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlStatement.ExecuteBatch") +{ + auto stmt = SqlStatement {}; + + CreateEmployeesTable(stmt); + + stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + + // Ensure that the batch insert works with different types of containers + // clang-format off + auto const firstNames = std::array { "Alice"sv, "Bob"sv, "Charlie"sv }; // random access STL container (contiguous) + auto const lastNames = std::list { "Smith"sv, "Johnson"sv, "Brown"sv }; // forward access STL container (non-contiguous) + unsigned const salaries[3] = { 50'000, 60'000, 70'000 }; // C-style array + // clang-format on + + stmt.ExecuteBatch(firstNames, lastNames, salaries); + + stmt.ExecuteDirect("SELECT FirstName, LastName, Salary FROM Employees ORDER BY Salary DESC"); + + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == "Charlie"); + REQUIRE(stmt.GetColumn(2) == "Brown"); + REQUIRE(stmt.GetColumn(3) == 70'000); + + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == "Bob"); + REQUIRE(stmt.GetColumn(2) == "Johnson"); + REQUIRE(stmt.GetColumn(3) == 60'000); + + REQUIRE(stmt.FetchRow()); + REQUIRE(stmt.GetColumn(1) == "Alice"); + REQUIRE(stmt.GetColumn(2) == "Smith"); + REQUIRE(stmt.GetColumn(3) == 50'000); + + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlStatement.ExecuteBatchNative") +{ + auto stmt = SqlStatement {}; + + stmt.ExecuteDirect("CREATE TABLE Test (A VARCHAR(8), B REAL, C INTEGER)"); + + stmt.Prepare("INSERT INTO Test (A, B, C) VALUES (?, ?, ?)"); + + // Ensure that the batch insert works with different types of contiguous containers + auto const first = std::array, 3> { "Hello", "World", "!" }; + auto const second = std::vector { 1.3, 2.3, 3.3 }; + unsigned const third[3] = { 50'000, 60'000, 70'000 }; + + stmt.ExecuteBatchNative(first, second, third); + + stmt.ExecuteDirect("SELECT A, B, C FROM Test ORDER BY C DESC"); + + REQUIRE(stmt.FetchRow()); + CHECK(stmt.GetColumn(1) == "!"); + CHECK_THAT(stmt.GetColumn(2), Catch::Matchers::WithinAbs(3.3, 0.000'001)); + CHECK(stmt.GetColumn(3) == 70'000); + + REQUIRE(stmt.FetchRow()); + CHECK(stmt.GetColumn(1) == "World"); + CHECK_THAT(stmt.GetColumn(2), Catch::Matchers::WithinAbs(2.3, 0.000'001)); + CHECK(stmt.GetColumn(3) == 60'000); + + REQUIRE(stmt.FetchRow()); + CHECK(stmt.GetColumn(1) == "Hello"); + CHECK_THAT(stmt.GetColumn(2), Catch::Matchers::WithinAbs(1.3, 0.000'001)); + CHECK(stmt.GetColumn(3) == 50'000); + + REQUIRE(!stmt.FetchRow()); +} + +TEST_CASE_METHOD(SqlTestFixture, "connection pool reusage", "[sql]") +{ + // auto-instanciating an SqlConnection + auto const id1 = [] { + auto connection = SqlConnection {}; + return connection.ConnectionId(); + }(); + + // Explicitly passing a borrowed SqlConnection + auto const id2 = [] { + auto conn = SqlConnection {}; + auto stmt = SqlStatement { conn }; + return stmt.Connection().ConnectionId(); + }(); + CHECK(id1 == id2); + + // &&-created SqlConnections are reused + auto const id3 = SqlConnection().ConnectionId(); + CHECK(id1 == id3); + + // Explicit constructor passing SqlConnectInfo always creates a new SqlConnection + auto const id4 = SqlConnection(SqlConnection::DefaultConnectInfo()).ConnectionId(); + CHECK(id1 != id4); +} + +struct CustomType +{ + int value; +}; + +template <> +struct SqlDataBinder +{ + static SQLRETURN InputParameter(SQLHSTMT hStmt, SQLUSMALLINT column, CustomType const& value) noexcept + { + return SqlDataBinder::InputParameter(hStmt, column, value.value); + } + + static SQLRETURN OutputColumn(SQLHSTMT hStmt, + SQLUSMALLINT column, + CustomType* result, + SQLLEN* indicator, + SqlDataBinderCallback& callback) noexcept + { + callback.PlanPostProcessOutputColumn([result]() { result->value = PostProcess(result->value); }); + return SqlDataBinder::OutputColumn(hStmt, column, &result->value, indicator, callback); + } + + static SQLRETURN GetColumn(SQLHSTMT hStmt, SQLUSMALLINT column, CustomType* result, SQLLEN* indicator) noexcept + { + return SqlDataBinder::GetColumn(hStmt, column, &result->value, indicator); + } + + static constexpr int PostProcess(int value) noexcept + { + return value |= 0x01; + } +}; + +TEST_CASE_METHOD(SqlTestFixture, "custom types", "[sql]") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value INT)"); + + // check custom type handling for input parameters + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(CustomType { 42 }); + + // check custom type handling for explicitly fetched output columns + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto result = stmt.GetColumn(1); + REQUIRE(result.value == 42); + + // check custom type handling for bound output columns + result = {}; + stmt.Prepare("SELECT Value FROM Test"); + stmt.BindOutputColumns(&result); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + REQUIRE(result.value == (42 | 0x01)); +} + +TEST_CASE_METHOD(SqlTestFixture, "LastInsertId") +{ + auto stmt = SqlStatement {}; + CreateEmployeesTable(stmt); + FillEmployeesTable(stmt); + + // 3 because we inserted 3 rows + REQUIRE(stmt.LastInsertId() == 3); +} + +TEST_CASE_METHOD(SqlTestFixture, "SELECT * FROM Table") +{ + auto stmt = SqlStatement {}; + CreateEmployeesTable(stmt); + FillEmployeesTable(stmt); + + stmt.ExecuteDirect("SELECT * FROM Employees"); + + REQUIRE(stmt.NumColumnsAffected() == 4); + + REQUIRE(stmt.FetchRow()); + CHECK(stmt.GetColumn(1) == 1); + CHECK(stmt.GetColumn(2) == "Alice"); + CHECK(stmt.GetColumn(3) == "Smith"); + CHECK(stmt.GetColumn(4) == 50'000); + + REQUIRE(stmt.FetchRow()); + CHECK(stmt.GetColumn(1) == 2); + CHECK(stmt.GetColumn(2) == "Bob"); + CHECK(stmt.GetColumn(3) == "Johnson"); + CHECK(stmt.GetColumn(4) == 60'000); +} + +TEST_CASE_METHOD(SqlTestFixture, "GetColumn in-place store variant") +{ + auto stmt = SqlStatement {}; + CreateEmployeesTable(stmt); + FillEmployeesTable(stmt); + + stmt.ExecuteDirect("SELECT FirstName, LastName, Salary FROM Employees"); + REQUIRE(stmt.FetchRow()); + + CHECK(stmt.GetColumn(1) == "Alice"); + + SqlVariant lastName; + stmt.GetColumn(2, &lastName); + CHECK(std::get(lastName) == "Smith"); + + SqlVariant salary; + stmt.GetColumn(3, &salary); + CHECK(std::get(salary) == 50'000); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlVariant: SqlDate") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value DATE NOT NULL)"); + + using namespace std::chrono_literals; + auto const expected = SqlDate { 2017y, std::chrono::August, 16d }; + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expected); + + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actual = stmt.GetColumn(1); + CHECK(std::get(actual) == expected); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlVariant: SqlTime") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value TIME NOT NULL)"); + + using namespace std::chrono_literals; + auto const expected = SqlTime { 12h, 34min, 56s }; + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expected); + + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actual = stmt.GetColumn(1); + + if (stmt.Connection().ServerType() == SqlServerType::POSTGRESQL) + { + WARN("PostgreSQL seems to report SQL_TYPE_DATE here. Skipping check, that would fail otherwise."); + // TODO: Find out why PostgreSQL reports SQL_TYPE_DATE instead of SQL_TYPE_TIME for SQL column type TIME. + return; + } + + CHECK(std::get(actual) == expected); +} + +static std::string MakeLargeText(size_t size) +{ + auto text = std::string(size, '\0'); + std::generate(text.begin(), text.end(), [i = 0]() mutable { return char('A' + (i++ % 26)); }); + return text; +} + +TEST_CASE_METHOD(SqlTestFixture, "InputParameter and GetColumn for very large values") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value TEXT)"); + auto const expectedText = MakeLargeText(8 * 1000); + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expectedText); + + SECTION("check handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + CHECK(stmt.GetColumn(1) == expectedText); + } + + SECTION("check handling for explicitly fetched output columns (in-place store)") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + std::string actualText; + stmt.GetColumn(1, &actualText); + CHECK(actualText == expectedText); + } + + SECTION("check handling for bound output columns") + { + std::string actualText; // intentionally an empty string, auto-growing behind the scenes + stmt.Prepare("SELECT Value FROM Test"); + stmt.BindOutputColumns(&actualText); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + REQUIRE(actualText.size() == expectedText.size()); + CHECK(actualText == expectedText); + } +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlDataBinder for SQL type: SqlFixedString") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value VARCHAR(8) NOT NULL)"); + + auto const expectedValue = SqlFixedString<8> { "Hello " }; + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expectedValue); + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actualValue = stmt.GetColumn>(1); + CHECK(actualValue == expectedValue); + + SECTION("Truncated result") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const truncatedValue = stmt.GetColumn>(1); + auto const truncatedStrView = truncatedValue.substr(0); + auto const expectedStrView = expectedValue.substr(0, 3); + CHECK(truncatedStrView == expectedStrView); // "Hel" + } + + SECTION("Trimmed result") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const trimmedValue = stmt.GetColumn>(1); + CHECK(trimmedValue == "Hello"); + } + } + + SECTION("check custom type handling for bound output columns") + { + stmt.Prepare("SELECT Value FROM Test"); + auto actualValue = SqlFixedString<8> {}; + stmt.BindOutputColumns(&actualValue); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + CHECK(actualValue == expectedValue); + } + + SECTION("check custom type handling for bound output columns (trimmed)") + { + stmt.Prepare("SELECT Value FROM Test"); + auto actualValue = SqlTrimmedFixedString<8> {}; + stmt.BindOutputColumns(&actualValue); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + CHECK(actualValue == "Hello"); + } +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlDataBinder for SQL type: SqlText") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value TEXT NOT NULL)"); + + using namespace std::chrono_literals; + auto const expectedValue = SqlText { "Hello, World!" }; + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expectedValue); + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actualValue = stmt.GetColumn(1); + CHECK(actualValue == expectedValue); + } + + SECTION("check custom type handling for bound output columns") + { + stmt.Prepare("SELECT Value FROM Test"); + auto actualValue = SqlText {}; + stmt.BindOutputColumns(&actualValue); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + CHECK(actualValue == expectedValue); + } +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlDataBinder for SQL type: SqlDateTime") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect(std::format("CREATE TABLE Test (Value {} NOT NULL)", + stmt.Connection().Traits().ColumnTypeName(SqlColumnType::DATETIME))); + + // With SQL Server or Oracle, we could use DATETIME2(7) and have nano-second precision (with 100ns resolution) + // The standard DATETIME and ODBC SQL_TIMESTAMP have only millisecond precision. + + using namespace std::chrono_literals; + auto const expectedValue = SqlDateTime(2017y, std::chrono::August, 16d, 17h, 30min, 45s, 123'000'000ns); + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expectedValue); + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actualValue = stmt.GetColumn(1); + CHECK(actualValue == expectedValue); + } + + SECTION("check custom type handling for bound output columns") + { + stmt.Prepare("SELECT Value FROM Test"); + auto actualValue = SqlDateTime {}; + stmt.BindOutputColumns(&actualValue); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + CHECK(actualValue == expectedValue); + } +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlDataBinder for SQL type: date") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value DATE NOT NULL)"); + using namespace std::chrono_literals; + auto const expected = SqlDate { std::chrono::year_month_day { 2017y, std::chrono::August, 16d } }; + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expected); + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actual = stmt.GetColumn(1); + REQUIRE(actual == expected); + } + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actual = stmt.GetColumn(1); + REQUIRE(actual == expected); + } + + SECTION("check custom type handling for bound output columns") + { + stmt.Prepare("SELECT Value FROM Test"); + auto actual = SqlDate {}; + stmt.BindOutputColumns(&actual); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + REQUIRE(actual == expected); + } +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlDataBinder for SQL type: time") +{ + auto stmt = SqlStatement {}; + stmt.ExecuteDirect("CREATE TABLE Test (Value TIME NOT NULL)"); + using namespace std::chrono_literals; + auto const expected = SqlTime(12h, 34min, 56s); + + stmt.Prepare("INSERT INTO Test (Value) VALUES (?)"); + stmt.Execute(expected); + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actual = stmt.GetColumn(1); + REQUIRE(actual == expected); + } + + SECTION("check custom type handling for explicitly fetched output columns") + { + stmt.ExecuteDirect("SELECT Value FROM Test"); + REQUIRE(stmt.FetchRow()); + auto const actual = stmt.GetColumn(1); + REQUIRE(actual == expected); + } + + SECTION("check custom type handling for bound output columns") + { + stmt.Prepare("SELECT Value FROM Test"); + auto actual = SqlTime {}; + stmt.BindOutputColumns(&actual); + stmt.Execute(); + REQUIRE(stmt.FetchRow()); + REQUIRE(actual == expected); + } +} + +struct QueryExpectations +{ + std::string_view sqlite; + std::string_view sqlServer; +}; + +void checkSqlQueryBuilder(SqlComposedQuery const& sqlQuery, + QueryExpectations const& expectations, + std::source_location const& location = std::source_location::current()) +{ + INFO(std::format("Test source location: {}:{}", location.file_name(), location.line())); + + auto const& sqliteFormatter = SqlQueryFormatter::Sqlite(); + auto const& sqlServerFormatter = SqlQueryFormatter::SqlServer(); + + CHECK(sqlQuery.ToSql(sqliteFormatter) == expectations.sqlite); + CHECK(sqlQuery.ToSql(sqlServerFormatter) == expectations.sqlServer); +}; + +TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Count") +{ + checkSqlQueryBuilder(SqlQueryBuilder::From("Table").Count(), + QueryExpectations { + .sqlite = "SELECT COUNT(*) FROM \"Table\"", + .sqlServer = "SELECT COUNT(*) FROM \"Table\"", + }); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.All") +{ + checkSqlQueryBuilder(SqlQueryBuilder::From("That").Select("a", "b").Select("c").GroupBy("a").OrderBy("b").All(), + QueryExpectations { + .sqlite = "SELECT \"a\", \"b\", \"c\" FROM \"That\" GROUP BY \"a\" ORDER BY \"b\" ASC", + .sqlServer = "SELECT \"a\", \"b\", \"c\" FROM \"That\" GROUP BY \"a\" ORDER BY \"b\" ASC", + }); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.First") +{ + checkSqlQueryBuilder(SqlQueryBuilder::From("That").Select("field1").OrderBy("id").First(), + QueryExpectations { + .sqlite = "SELECT \"field1\" FROM \"That\" ORDER BY \"id\" ASC LIMIT 1", + .sqlServer = "SELECT TOP 1 \"field1\" FROM \"That\" ORDER BY \"id\" ASC", + }); +} + +TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Range") +{ + checkSqlQueryBuilder( + SqlQueryBuilder::From("That").Select("foo", "bar").OrderBy("id").Range(200, 50), + QueryExpectations { + .sqlite = "SELECT \"foo\", \"bar\" FROM \"That\" ORDER BY \"id\" ASC LIMIT 50 OFFSET 200", + .sqlServer = + "SELECT \"foo\", \"bar\" FROM \"That\" ORDER BY \"id\" ASC OFFSET 200 ROWS FETCH NEXT 50 ROWS ONLY", + }); +} diff --git a/src/tests/ModelAssociationsTests.cpp b/src/tests/ModelAssociationsTests.cpp new file mode 100644 index 00000000..3d36005d --- /dev/null +++ b/src/tests/ModelAssociationsTests.cpp @@ -0,0 +1,426 @@ +#include "../Lightweight/Model/All.hpp" +#include "../Lightweight/SqlConnection.hpp" +#include "Utils.hpp" + +#include +#include + +struct Artist; +struct Track; +struct Publisher; + +struct Artist: Model::Record +{ + Model::Field name; + Model::HasMany tracks; + + Artist(): + Record { "artists" }, + name { *this }, + tracks { *this } + { + } + + Artist(Artist&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) }, + tracks { *this, std::move(other.tracks) } + { + } +}; + +struct Track: Model::Record +{ + Model::Field title; + Model::BelongsTo artist; + + Track(): + Record { "tracks" }, + title { *this }, + artist { *this } + { + } + + Track(Track&& other) noexcept: + Record { std::move(other) }, + title { *this, std::move(other.title) }, + artist { *this, std::move(other.artist) } + { + } +}; + +TEST_CASE_METHOD(SqlTestFixture, "Model.BelongsTo", "[model]") +{ + CreateModelTable(); + CreateModelTable(); + + Artist artist; + artist.name = "Snoop Dog"; + artist.Save(); + REQUIRE(artist.Id().value); + + Track track1; + track1.title = "Wuff"; + track1.artist = artist; // track1 "BelongsTo" artist + track1.Save(); + REQUIRE(track1.Id().value); + + CHECK(track1.artist->Inspect() == artist.Inspect()); + + artist.Destroy(); + CHECK(Artist::Count() == 0); + CHECK(Track::Count() == 0); + // Destroying the artist must also destroy the track, due to the foreign key constraint. +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.HasMany", "[model]") +{ + CreateModelTable(); + CreateModelTable(); + + Artist artist; + artist.name = "Snoop Dog"; + artist.Save(); + + Track track1; + track1.title = "Wuff"; + track1.artist = artist; + track1.Save(); + + Track track2; + track2.title = "Paff Dog"; + track2.artist = artist; + track2.Save(); + + REQUIRE(artist.tracks.IsLoaded() == false); + REQUIRE(artist.tracks.IsEmpty() == false); + REQUIRE(artist.tracks.Count() == 2); + artist.tracks.Load(); + REQUIRE(artist.tracks.IsLoaded() == true); + REQUIRE(artist.tracks.Count() == 2); // Using cached value + REQUIRE(artist.tracks[0].Inspect() == track1.Inspect()); + REQUIRE(artist.tracks[1].Inspect() == track2.Inspect()); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.HasOne", "[model]") +{ + struct Suppliers; + struct Account; + + struct Suppliers: Model::Record + { + Model::Field name; + Model::HasOne account; + + Suppliers(): + Record { "suppliers" }, + name { *this }, + account { *this } + { + } + + Suppliers(Suppliers&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) }, + account { *this, std::move(other.account) } + { + } + }; + + struct Account: Model::Record + { + Model::Field iban; + Model::BelongsTo supplier; + + Account(): + Record { "accounts" }, + iban { *this }, + supplier { *this } + { + } + + Account(Account&& other) noexcept: + Record { std::move(other) }, + iban { *this, std::move(other.iban) }, + supplier { *this, std::move(other.supplier) } + { + } + }; + + CreateModelTable(); + CreateModelTable(); + + Suppliers supplier; + supplier.name = "Supplier"; + supplier.Save(); + + Account account; + account.iban = "DE123456789"; + account.supplier = supplier; + account.Save(); + + REQUIRE(supplier.account.IsLoaded() == false); + supplier.account.Load(); + REQUIRE(supplier.account.IsLoaded() == true); + REQUIRE(supplier.account->Inspect() == account.Inspect()); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.HasOneThrough", "[model]") +{ + // {{{ models + struct Suppliers; + struct Account; + struct AccountHistory; + + struct Suppliers: Model::Record + { + Model::HasOne account; + Model::HasOneThrough accountHistory; + Model::Field name; + + // {{{ ctors + Suppliers(): + Record { "suppliers" }, + account { *this }, + accountHistory { *this }, + name { *this } + { + } + + Suppliers(Suppliers&& other) noexcept: + Record { std::move(other) }, + account { *this, std::move(other.account) }, + accountHistory { *this, std::move(other.accountHistory) }, + name { *this, std::move(other.name) } + { + } + // }}} + }; + + struct Account: Model::Record + { + Model::Field iban; + Model::BelongsTo supplier; + Model::HasOne accountHistory; + + // {{{ ctors + Account(): + Record { "accounts" }, + iban { *this }, + supplier { *this }, + accountHistory { *this } + { + } + + Account(Account&& other) noexcept: + Record { std::move(other) }, + iban { *this, std::move(other.iban) }, + supplier { *this, std::move(other.supplier) }, + accountHistory { *this, std::move(other.accountHistory) } + { + } + // }}} + }; + + struct AccountHistory: Model::Record + { + Model::BelongsTo account; + Model::Field description; + + // {{{ ctors + AccountHistory(): + Record { "account_histories" }, + account { *this }, + description { *this } + { + } + + AccountHistory(AccountHistory&& other) noexcept: + Record { std::move(other) }, + account { *this, std::move(other.account) }, + description { *this, std::move(other.description) } + { + } + // }}} + }; + // }}} + + CreateModelTable(); + CreateModelTable(); + CreateModelTable(); + + Suppliers supplier; + supplier.name = "The Supplier"; + supplier.Save(); + + Account account; + account.supplier = supplier; + account.iban = "DE123456789"; + account.Save(); + + AccountHistory accountHistory; + accountHistory.account = account; + accountHistory.description = "Initial deposit"; + accountHistory.Save(); + + REQUIRE(supplier.accountHistory.IsLoaded() == false); + REQUIRE(supplier.accountHistory->Inspect() == accountHistory.Inspect()); // auto-loads the accountHistory + REQUIRE(supplier.accountHistory.IsLoaded() == true); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.HasManyThrough", "[model]") +{ + // {{{ Models + struct Physician; + struct Appointment; + struct Patient; + + struct Physician: Model::Record + { + Model::Field name; + Model::HasMany appointments; + Model::HasManyThrough patients; + + Physician(): + Record { "physicians" }, + name { *this }, + appointments { *this }, + patients { *this } + { + } + + Physician(Physician&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) }, + appointments { *this, std::move(other.appointments) }, + patients { *this, std::move(other.patients) } + { + } + }; + + struct Appointment: Model::Record + { + Model::Field date; + Model::Field comment; + Model::BelongsTo physician; + Model::BelongsTo patient; + + Appointment(): + Record { "appointments" }, + date { *this }, + comment { *this }, + physician { *this }, + patient { *this } + { + } + + Appointment(Appointment&& other) noexcept: + Record { std::move(other) }, + date { *this, std::move(other.date) }, + comment { *this, std::move(other.comment) }, + physician { *this, std::move(other.physician) }, + patient { *this, std::move(other.patient) } + { + } + }; + + struct Patient: Model::Record + { + Model::Field name; + Model::Field comment; + Model::HasMany appointments; + Model::HasManyThrough physicians; + + Patient(): + Record { "patients" }, + name { *this }, + comment { *this }, + appointments { *this }, + physicians { *this } + { + } + + Patient(Patient&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) }, + comment { *this, std::move(other.comment) }, + appointments { *this, std::move(other.appointments) }, + physicians { *this, std::move(other.physicians) } + { + } + }; + // }}} + + CreateModelTable(); + CreateModelTable(); + CreateModelTable(); + + Physician physician1; + physician1.name = "Dr. House"; + physician1.Save(); + + Physician physician2; + physician2.name = "Granny"; + physician2.Save(); + + Patient patient1; + patient1.name = "Blooper"; + patient1.comment = "Prefers morning times"; + patient1.Save(); + + Patient patient2; + patient2.name = "Valentine"; + patient2.comment = "always friendly"; + patient2.Save(); + + Appointment patient1Apointment1; + patient1Apointment1.date = SqlDateTime::Now(); + patient1Apointment1.patient = patient1; + patient1Apointment1.physician = physician2; + patient1Apointment1.comment = "Patient is a bit nervous"; + patient1Apointment1.Save(); + + Appointment patient1Apointment2; + patient1Apointment2.date = SqlDateTime::Now(); + patient1Apointment2.patient = patient1; + patient1Apointment2.physician = physician1; + patient1Apointment2.comment = "Patient is a bit nervous, again"; + patient1Apointment2.Save(); + + Appointment patient2Apointment1; + patient2Apointment1.date = SqlDateTime::Now(); + patient2Apointment1.patient = patient2; + patient2Apointment1.physician = physician1; + patient2Apointment1.comment = "Patient is funny"; + patient2Apointment1.Save(); + + auto const queriedCount = physician1.patients.Count(); + CHECK(queriedCount == 2); + + auto const& physician1Patients = physician1.patients.All(); + REQUIRE(physician1Patients.size() == 2); + CHECK(physician1Patients.at(0).Inspect() == patient1.Inspect()); + CHECK(physician1Patients.at(1).Inspect() == patient2.Inspect()); + + CHECK(patient1.physicians.Count() == 2); + CHECK(patient2.physicians.Count() == 1); + + // Test Each() method + size_t numPatientsIterated = 0; + std::vector retrievedPatients; + physician2.patients.Each([&](Patient& patient) { + // NB: Mind, SQLite does not seem to like issuing another query on the memory database while + // we are currently fetching the results via the Each() call. So we moved the results to a + // vector and then inspect them. + REQUIRE(numPatientsIterated == 0); + ++numPatientsIterated; + retrievedPatients.emplace_back(std::move(patient)); + }); + Patient const& patient = retrievedPatients.at(0); + CHECK(patient.Inspect() == patient1.Inspect()); // Blooper + CHECK(patient.comment.Value() == "Prefers morning times"); + CHECK(patient.physicians.Count() == 2); + CHECK(patient.physicians.IsLoaded() == false); + CHECK(patient.physicians[0].name.Value() == "Granny"); + CHECK(patient.physicians[0].Inspect() == physician2.Inspect()); +} diff --git a/src/tests/ModelTests.cpp b/src/tests/ModelTests.cpp new file mode 100644 index 00000000..e7f397f0 --- /dev/null +++ b/src/tests/ModelTests.cpp @@ -0,0 +1,353 @@ +#include "Utils.hpp" + +#include +#include + +#include +#include + +using namespace std::string_view_literals; + +struct Author; +struct Book; +struct Company; +struct Job; +struct Person; +struct Phone; + +TEST_CASE_METHOD(SqlTestFixture, "Model.Move", "[model]") +{ + struct MovableRecord: public Model::Record + { + Model::Field name; + + MovableRecord(): + Record { "movables" }, + name { *this } + { + } + + MovableRecord(MovableRecord&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) } + { + } + }; + + // Ensure move constructor is working as expected. + // Inspect() touches the most internal data structures, so we use this call to verify. + + CreateModelTable(); + + MovableRecord record; + record.name = "Foxy Fox"; + record.Save(); + auto const originalText = record.Inspect(); + INFO("Original: " << originalText); + + MovableRecord movedRecord(std::move(record)); + auto const movedText = movedRecord.Inspect(); + REQUIRE(movedText == originalText); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.Field: SqlTrimmedString", "[model]") +{ + struct TrimmedStringRecord: Model::Record + { + Model::Field name; + + TrimmedStringRecord(): + Record { "trimmed_strings" }, + name { *this } + { + } + + TrimmedStringRecord(TrimmedStringRecord&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) } + { + } + }; + + CreateModelTable(); + + TrimmedStringRecord record; + record.name = SqlTrimmedString { " Hello, World! " }; + record.Save(); + record.Reload(); // Ensure we fetch name from the database and got trimmed on fetch. + + CHECK(record.name == SqlTrimmedString { " Hello, World!" }); +} + +struct Author: Model::Record +{ + Model::Field name; + Model::HasMany books; + + Author(): + Record { "authors" }, + name { *this }, + books { *this } + { + } + + Author(Author&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) }, + books { *this, std::move(other.books) } + { + } +}; + +struct Book: Model::Record +{ + Model::Field, 2, "title"> title; + Model::Field isbn; + Model::BelongsTo author; + + Book(): + Record { "books", "id" }, + title { *this }, + isbn { *this }, + author { *this } + { + } + + Book(Book&& other) noexcept: + Record { std::move(other) }, + title { *this, std::move(other.title) }, + isbn { *this, std::move(other.isbn) }, + author { *this, std::move(other.author) } + { + } +}; + +TEST_CASE_METHOD(SqlTestFixture, "Model.Create", "[model]") +{ + CreateModelTable(); + CreateModelTable(); + + Author author; + author.name = "Bjarne Stroustrup"; + author.Save(); + REQUIRE(author.Id() == 1); + REQUIRE(author.books.Count() == 0); + + Book book1; + book1.title = "The C++ Programming Language"; + book1.isbn = "978-0-321-56384-2"; + book1.author = author; + book1.Save(); + REQUIRE(book1.Id() == 1); + REQUIRE(Book::Count() == 1); + REQUIRE(author.books.Count() == 1); + + Book book2; + book2.title = "A Tour of C++"; + book2.isbn = "978-0-321-958310"; + book2.author = author; + book2.Save(); + REQUIRE(book2.Id() == 2); + REQUIRE(Book::Count() == 2); + REQUIRE(author.books.Count() == 2); + + // Also take the chance to ensure the formatter works. + REQUIRE(std::format("{}", author) == author.Inspect()); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.Load", "[model]") +{ + Model::CreateSqlTables(); + + Author author; + author.name = "Bjarne Stroustrup"; + author.Save(); + + Book book; + book.title = "The C++ Programming Language"; + book.isbn = "978-0-321-56384-2"; + book.author = author; + book.Save(); + + Book bookLoaded; + bookLoaded.Load(book.Id()); + INFO("Book: " << book); + CHECK(bookLoaded.Id() == book.Id()); + CHECK(bookLoaded.title == book.title); + CHECK(bookLoaded.isbn == book.isbn); + CHECK(bookLoaded.author == book.author); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.Find", "[model]") +{ + Model::CreateSqlTables(); + + Author author; + author.name = "Bjarne Stroustrup"; + author.Save(); + + Book book; + book.title = "The C++ Programming Language"; + book.isbn = "978-0-321-56384-2"; + book.author = author; + book.Save(); + + Book bookLoaded = Book::Find(book.Id()).value(); + INFO("Book: " << book); + CHECK(bookLoaded.Id() == book.Id()); // primary key + CHECK(bookLoaded.title == book.title); // Field<> + CHECK(bookLoaded.isbn == book.isbn); // Field<> + CHECK(bookLoaded.author == book.author); // BelongsTo<> +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.Update", "[model]") +{ + Model::CreateSqlTables(); + + Author author; + author.name = "Bjarne Stroustrup"; + author.Save(); + + Book book; + book.title = "The C++ Programming Language"; + book.isbn = "978-0-321-56384-2"; + book.author = author; + book.Save(); + + book.isbn = "978-0-321-958310"; + book.Save(); + + Book bookRead = Book::Find(book.Id()).value(); + CHECK(bookRead.Id() == book.Id()); + CHECK(bookRead.title == book.title); + CHECK(bookRead.isbn == book.isbn); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.Destroy", "[model]") +{ + CreateModelTable(); + + Author author1; + author1.name = "Bjarne Stroustrup"; + author1.Save(); + REQUIRE(Author::Count() == 1); + + Author author2; + author2.name = "John Doe"; + author2.Save(); + REQUIRE(Author::Count() == 2); + + author1.Destroy(); + REQUIRE(Author::Count() == 1); +} + +TEST_CASE_METHOD(SqlTestFixture, "Model.All", "[model]") +{ + CreateModelTable(); + + Author author1; + author1.name = "Bjarne Stroustrup"; + author1.Save(); + + Author author2; + author2.name = "John Doe"; + author2.Save(); + + Author author3; + author3.name = "Some very long name"; + author3.Save(); + + Author author4; + author4.name = "Shorty"; + author4.Save(); + + auto authors = Author::All(); + REQUIRE(authors.size() == 4); + CHECK(authors[0].name == author1.name); + CHECK(authors[1].name == author2.name); + CHECK(authors[2].name == author3.name); + CHECK(authors[3].name == author4.name); +} + +struct ColumnTypesRecord: Model::Record +{ + Model::Field stringColumn; + Model::Field textColumn; + + ColumnTypesRecord(): + Record { "column_types" }, + stringColumn { *this }, + textColumn { *this } + { + } + + ColumnTypesRecord(ColumnTypesRecord&& other) noexcept: + Record { std::move(other) }, + stringColumn { *this, std::move(other.stringColumn) }, + textColumn { *this, std::move(other.textColumn) } + { + } +}; + +TEST_CASE_METHOD(SqlTestFixture, "Model.ColumnTypes", "[model]") +{ + CreateModelTable(); + + ColumnTypesRecord record; + record.stringColumn = "Hello"; + record.textColumn = SqlText { ", World!" }; + record.Save(); + + ColumnTypesRecord record2 = ColumnTypesRecord::Find(record.Id()).value(); + CHECK(record2.stringColumn == record.stringColumn); + CHECK(record2.textColumn == record.textColumn); +} + +struct Employee: Model::Record +{ + Model::Field name; + Model::Field isSenior; + + Employee(): + Record { "employees" }, + name { *this }, + isSenior { *this } + { + } + + Employee(Employee&& other) noexcept: + Record { std::move(other) }, + name { *this, std::move(other.name) }, + isSenior { *this, std::move(other.isSenior) } + { + } +}; + +TEST_CASE_METHOD(SqlTestFixture, "Model.Where", "[model]") +{ + CreateModelTable(); + + Employee employee1; + employee1.name = "John Doe"; + employee1.isSenior = false; + employee1.Save(); + + Employee employee2; + employee2.name = "Jane Doe"; + employee2.isSenior = true; + employee2.Save(); + + Employee employee3; + employee3.name = "John Smith"; + employee3.isSenior = true; + employee3.Save(); + + auto employees = Employee::Where("is_senior"sv, true).All(); + for (const auto& employee: employees) + INFO("Employee: {}" << employee); // FIXME: breaks due to field name being NULL + REQUIRE(employees.size() == 2); + CHECK(employees[0].Id() == employee2.Id()); + CHECK(employees[0].name == employee2.name); + CHECK(employees[1].Id() == employee3.Id()); + CHECK(employees[1].name == employee3.name); +} diff --git a/src/tests/Utils.hpp b/src/tests/Utils.hpp new file mode 100644 index 00000000..cec99304 --- /dev/null +++ b/src/tests/Utils.hpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#if defined(_WIN32) || defined(_WIN64) + #include +#endif + +#include "../Lightweight/Model/All.hpp" +#include "../Lightweight/SqlConnectInfo.hpp" +#include "../Lightweight/SqlDataBinder.hpp" +#include "../Lightweight/SqlLogger.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// Refer to an in-memory SQLite database (and assuming the sqliteodbc driver is installed) +// See: +// - https://www.sqlite.org/inmemorydb.html +// - http://www.ch-werner.de/sqliteodbc/ +// - https://github.com/softace/sqliteodbc +// +auto const inline DefaultTestConnectionString = SqlConnectionString { + .value = std::format("DRIVER={};Database={}", +#if defined(_WIN32) || defined(_WIN64) + "SQLite3 ODBC Driver", +#else + "SQLite3", +#endif + "file::memory:"), +}; + +class ScopedSqlNullLogger: public SqlLogger +{ + private: + SqlLogger& m_previousLogger = SqlLogger::GetLogger(); + + public: + ScopedSqlNullLogger() + { + SqlLogger::SetLogger(*this); + } + + ~ScopedSqlNullLogger() override + { + SqlLogger::SetLogger(m_previousLogger); + } + + void OnWarning(std::string_view const&) override {} + void OnError(SqlError, SqlErrorInfo const&, std::source_location) override {} + void OnConnectionOpened(SqlConnection const&) override {} + void OnConnectionClosed(SqlConnection const&) override {} + void OnConnectionIdle(SqlConnection const&) override {} + void OnConnectionReuse(SqlConnection const&) override {} + void OnExecuteDirect(std::string_view const&) override {} + void OnPrepare(std::string_view const&) override {} + void OnExecute() override {} + void OnExecuteBatch() override {} + void OnFetchedRow() override {} + void OnStats(SqlConnectionStats const&) override {} +}; + +class SqlTestFixture +{ + public: + static inline std::string_view const testDatabaseName = "LightweightTest"; + + using MainProgramArgs = std::tuple; + + static std::expected Initialize(int argc, char** argv) + { + using namespace std::string_view_literals; + int i = 1; + for (; i < argc; ++i) + { + if (argv[i] == "--trace-sql"sv) + SqlLogger::SetLogger(SqlLogger::TraceLogger()); + else if (argv[i] == "--trace-model"sv) + Model::QueryLogger::Set(Model::QueryLogger::StandardLogger()); + else if (argv[i] == "--help"sv || argv[i] == "-h"sv) + { + std::println("{} [--trace-sql] [--trace-model] [[--] [Catch2 flags ...]]", argv[0]); + return std::unexpected { EXIT_SUCCESS }; + } + else if (argv[i] == "--"sv) + { + ++i; + break; + } + else + break; + } + + if (i < argc) + argv[i - 1] = argv[0]; + + if (auto const* s = std::getenv("ODBC_CONNECTION_STRING"); s && *s) + { + std::println("Using ODBC connection string: '{}'", SanitizePwd(s)); + SqlConnection::SetDefaultConnectInfo(SqlConnectionString { s }); + } + else + { + // Use an in-memory SQLite3 database by default (for testing purposes) + std::println("Using default ODBC connection string: '{}'", DefaultTestConnectionString.value); + SqlConnection::SetDefaultConnectInfo(DefaultTestConnectionString); + } + + auto sqlConnection = SqlConnection(); + std::println("Running test cases against: {} ({}) (identified as: {})", + sqlConnection.ServerName(), + sqlConnection.ServerVersion(), + sqlConnection.ServerType()); + + SqlConnection::SetPostConnectedHook(&SqlTestFixture::PostConnectedHook); + + return MainProgramArgs { argc - (i - 1), argv + (i - 1) }; + } + + static void PostConnectedHook(SqlConnection& connection) + { + switch (connection.ServerType()) + { + case SqlServerType::SQLITE: { + auto stmt = SqlStatement { connection }; + // Enable foreign key constraints for SQLite + (void) stmt.ExecuteDirect("PRAGMA foreign_keys = ON"); + break; + } + case SqlServerType::MICROSOFT_SQL: + case SqlServerType::POSTGRESQL: + case SqlServerType::ORACLE: + case SqlServerType::MYSQL: + case SqlServerType::UNKNOWN: + break; + } + } + + SqlTestFixture() + { + REQUIRE(SqlConnection().IsAlive()); + DropAllTablesInDatabase(); + SqlConnection::KillAllIdle(); + } + + virtual ~SqlTestFixture() + { + SqlConnection::KillAllIdle(); + } + + template + void CreateModelTable() + { + auto const tableName = T().TableName(); + m_createdTables.emplace_back(tableName); + T::CreateTable(); + } + + private: + static std::string SanitizePwd(std::string_view input) + { + std::regex const pwdRegex { + R"(PWD=.*?;)", + std::regex_constants::ECMAScript | std::regex_constants::icase, + }; + std::stringstream outputString; + std::regex_replace( + std::ostreambuf_iterator { outputString }, input.begin(), input.end(), pwdRegex, "Pwd=***;"); + return outputString.str(); + } + + std::vector GetAllTableNames() + { + auto result = std::vector(); + auto stmt = SqlStatement(); + auto const sqlResult = SQLTables(stmt.NativeHandle(), + (SQLCHAR*) testDatabaseName.data(), + (SQLSMALLINT) testDatabaseName.size(), + nullptr, + 0, + nullptr, + 0, + (SQLCHAR*) "TABLE", + SQL_NTS); + if (SQL_SUCCEEDED(sqlResult)) + { + while (stmt.FetchRow()) + { + result.emplace_back(stmt.GetColumn(3)); // table name + } + } + return result; + } + + void DropAllTablesInDatabase() + { + auto stmt = SqlStatement {}; + + switch (stmt.Connection().ServerType()) + { + case SqlServerType::MICROSOFT_SQL: + SqlConnection::KillAllIdle(); + (void) stmt.ExecuteDirect(std::format("USE {}", "master")); + (void) stmt.ExecuteDirect(std::format("DROP DATABASE IF EXISTS \"{}\"", testDatabaseName)); + (void) stmt.ExecuteDirect(std::format("CREATE DATABASE \"{}\"", testDatabaseName)); + (void) stmt.ExecuteDirect(std::format("USE {}", testDatabaseName)); + break; + case SqlServerType::POSTGRESQL: + if (m_createdTables.empty()) + m_createdTables = GetAllTableNames(); + for (auto i = m_createdTables.rbegin(); i != m_createdTables.rend(); ++i) + (void) stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\" CASCADE", *i)); + break; + default: + for (auto i = m_createdTables.rbegin(); i != m_createdTables.rend(); ++i) + (void) stmt.ExecuteDirect(std::format("DROP TABLE IF EXISTS \"{}\"", *i)); + break; + } + m_createdTables.clear(); + } + + std::vector m_createdTables; +}; + +// {{{ ostream support for Lightweight, for debugging purposes +inline std::ostream& operator<<(std::ostream& os, Model::RecordId value) +{ + return os << "ModelId { " << value.value << " }"; +} + +inline std::ostream& operator<<(std::ostream& os, Model::AbstractRecord const& value) +{ + return os << std::format("{}", value); +} + +inline std::ostream& operator<<(std::ostream& os, SqlResult const& result) +{ + if (result) + return os << "SqlResult { success }"; + return os << "SqlResult { error: " << result.error() << " }"; +} + +inline std::ostream& operator<<(std::ostream& os, SqlTrimmedString const& value) +{ + return os << std::format("SqlTrimmedString {{ '{}' }}", value); +} + +inline std::ostream& operator<<(std::ostream& os, SqlDate const& date) +{ + auto const ymd = date.value(); + return os << std::format("SqlDate {{ {}-{}-{} }}", ymd.year(), ymd.month(), ymd.day()); +} + +inline std::ostream& operator<<(std::ostream& os, SqlTime const& time) +{ + auto const value = time.value(); + return os << std::format("SqlTime {{ {:02}:{:02}:{:02}.{:06} }}", + value.hours().count(), + value.minutes().count(), + value.seconds().count(), + value.subseconds().count()); +} + +inline std::ostream& operator<<(std::ostream& os, SqlDateTime const& datetime) +{ + auto const value = datetime.value(); + auto const totalDays = std::chrono::floor(value); + auto const ymd = std::chrono::year_month_day { totalDays }; + auto const hms = std::chrono::hh_mm_ss { std::chrono::floor( + value - totalDays) }; + return os << std::format("SqlDateTime {{ {:04}-{:02}-{:02} {:02}:{:02}:{:02}.{:09} }}", + (int) ymd.year(), + (unsigned) ymd.month(), + (unsigned) ymd.day(), + hms.hours().count(), + hms.minutes().count(), + hms.seconds().count(), + hms.subseconds().count()); +} + +template +inline std::ostream& operator<<(std::ostream& os, SqlResult const& result) +{ + if (result) + return os << "SqlResult { value: " << result.value() << " }"; + return os << "SqlResult { error: " << result.error() << " }"; +} + +template +inline std::ostream& operator<<(std::ostream& os, SqlFixedString const& value) +{ + if constexpr (PostOp == SqlStringPostRetrieveOperation::NOTHING) + return os << std::format("SqlFixedString<{}> {{ '{}' }}", N, value.data()); + if constexpr (PostOp == SqlStringPostRetrieveOperation::TRIM_RIGHT) + return os << std::format("SqlTrimmedFixedString<{}> {{ '{}' }}", N, value.data()); +} + +template +inline std::ostream& operator<<(std::ostream& os, + Model::Field const& field) +{ + return os << std::format("Field<{}:{}: {}>", TheTableColumnIndex, TheColumnName.value, field.Value()); +} + +// }}} diff --git a/src/tools/CMakeLists.txt b/src/tools/CMakeLists.txt new file mode 100644 index 00000000..0651df4b --- /dev/null +++ b/src/tools/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(ddl2cpp ddl2cpp.cpp) +target_link_libraries(ddl2cpp PRIVATE Lightweight::Lightweight) +target_compile_features(ddl2cpp PUBLIC cxx_std_23) diff --git a/src/tools/ddl2cpp.cpp b/src/tools/ddl2cpp.cpp new file mode 100644 index 00000000..d82b5566 --- /dev/null +++ b/src/tools/ddl2cpp.cpp @@ -0,0 +1,397 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +// TODO: have an OdbcConnectionString API to help compose/decompose connection settings +// TODO: move SanitizePwd function into that API, like `string OdbcConnectionString::PrettyPrintSanitized()` + +// TODO: get inspired by .NET's Dapper, and EF Core APIs + +namespace +{ + +constexpr auto finally(auto&& cleanupRoutine) noexcept +{ + struct Finally + { + std::remove_cvref_t cleanup; + ~Finally() + { + cleanup(); + } + }; + return Finally { std::forward(cleanupRoutine) }; +} + +std::string MakeType(SqlSchema::Column const& column) +{ + using ColumnType = SqlColumnType; + switch (column.type) + { + case ColumnType::CHAR: + if (column.size == 1) + return "char"; + else + return std::format("SqlTrimmedString<{}>", column.size); + case ColumnType::STRING: + if (column.size == 1) + return "char"; + else + return std::format("std::string", column.size); + case ColumnType::TEXT: + return std::format("SqlText"); + case ColumnType::BOOLEAN: + return "bool"; + case ColumnType::INTEGER: + return "int"; + case ColumnType::REAL: + return "double"; + case ColumnType::BLOB: + return "std::vector"; + case ColumnType::DATE: + return "SqlDate"; + case ColumnType::TIME: + return "SqlTime"; + case ColumnType::DATETIME: + return "SqlDateTime"; + case ColumnType::UNKNOWN: + break; + } + return "void"; +} + +std::string MakeVariableName(SqlSchema::FullyQualifiedTableName const& table) +{ + auto name = std::format("{}", table.table); + name.at(0) = std::tolower(name.at(0)); + return name; +} + +constexpr bool isVowel(char c) noexcept +{ + switch (c) + { + case 'a': + case 'e': + case 'i': + case 'o': + case 'u': + return true; + default: + return false; + } +} + +std::string MakePluralVariableName(SqlSchema::FullyQualifiedTableName const& table) +{ + auto const& sqlName = table.table; + if (sqlName.back() == 'y' && sqlName.size() > 1 && !isVowel(sqlName.at(sqlName.size() - 2))) + { + auto name = std::format("{}ies", sqlName.substr(0, sqlName.size() - 1)); + name.at(0) = std::tolower(name.at(0)); + return name; + } + auto name = std::format("{}s", sqlName); + name.at(0) = std::tolower(name.at(0)); + return name; +} + +class CxxModelPrinter +{ + private: + mutable std::vector m_forwwardDeclarations; + std::stringstream m_definitions; + + public: + std::string str(std::string_view modelNamespace) const + { + std::ranges::sort(m_forwwardDeclarations); + + std::stringstream output; + output << "#include \"src/Lightweight/Model/All.hpp\"\n\n"; + if (!modelNamespace.empty()) + output << std::format("namespace {}\n{{\n\n", modelNamespace); + for (auto const& name: m_forwwardDeclarations) + output << std::format("struct {};\n", name); + output << "\n"; + output << m_definitions.str(); + if (!modelNamespace.empty()) + output << std::format("}} // end namespace {}\n", modelNamespace); + + return output.str(); + } + + void PrintTable(SqlSchema::Table const& table) + { + m_forwwardDeclarations.push_back(table.name); + + std::string cxxPrimaryKeys; + for (auto const& key: table.primaryKeys) + { + if (!cxxPrimaryKeys.empty()) + cxxPrimaryKeys += ", "; + cxxPrimaryKeys += '"' + key + '"'; + } + + m_definitions << std::format("struct {0} final: Model::Record<{0}>\n", table.name); + m_definitions << std::format("{{\n"); + + int columnPosition = 0; + for (auto const& column: table.columns) + { + ++columnPosition; + std::string type = MakeType(column); + if (column.isPrimaryKey) + continue; + if (column.isForeignKey) + continue; + m_definitions << std::format(" Model::Field<{}, {}, \"{}\"{}> {};\n", + type, + columnPosition, + column.name, + column.isNullable ? ", Nullable" : "", + column.name); + } + + columnPosition = 0; + for (auto const& foreignKey: table.foreignKeys) + { + ++columnPosition; + m_definitions << std::format(" Model::BelongsTo<{}, {}, \"{}\"> {};\n", + foreignKey.primaryKey.table, + columnPosition, + foreignKey.foreignKey.column, + MakeVariableName(foreignKey.primaryKey.table)); + } + + for (SqlSchema::ForeignKeyConstraint const& foreignKey: table.externalForeignKeys) + { + m_definitions << std::format(" Model::HasMany<{}, \"{}\"> {};\n", + foreignKey.foreignKey.table, + foreignKey.foreignKey.column, + MakePluralVariableName(foreignKey.foreignKey.table)); + } + + std::vector fieldNames; + for (auto const& column: table.columns) + if (!column.isPrimaryKey && !column.isForeignKey) + fieldNames.push_back(column.name); + + // Create default ctor + auto const cxxModelTypeName = table.name; + m_definitions << '\n'; + m_definitions << std::format(" {}():\n", cxxModelTypeName); + m_definitions << std::format(" Record {{ \"{}\", {} }}", table.name, cxxPrimaryKeys); + for (auto const& fieldName: fieldNames) + m_definitions << std::format(",\n {} {{ *this }}", fieldName, fieldName); + for (auto const& constraint: table.foreignKeys) + m_definitions << std::format(",\n {} {{ *this }}", MakeVariableName(constraint.primaryKey.table)); + for (auto const& constraint: table.externalForeignKeys) + m_definitions << std::format(",\n {} {{ *this }}", + MakePluralVariableName(constraint.foreignKey.table)); + m_definitions << "\n"; + m_definitions << " {\n"; + m_definitions << " }\n"; + + m_definitions << "\n"; + + // Create move ctor + m_definitions << std::format(" {0}({0}&& other) noexcept:\n", cxxModelTypeName); + m_definitions << std::format(" Record {{ std::move(other) }}"); + for (auto const& fieldName: fieldNames) + m_definitions << std::format(",\n {0} {{ *this, std::move(other.{0}) }}", fieldName); + for (auto const& constraint: table.foreignKeys) + m_definitions << std::format(",\n {0} {{ *this, std::move(other.{0}) }}", + MakeVariableName(constraint.primaryKey.table)); + for (auto const& constraint: table.externalForeignKeys) + m_definitions << std::format(",\n {0} {{ *this, std::move(other.{0}) }}", + MakePluralVariableName(constraint.foreignKey.table)); + m_definitions << "\n"; + m_definitions << " {\n"; + m_definitions << " }\n"; + + m_definitions << "};\n\n"; + } +}; + +void CreateTestTables() +{ + auto constexpr createStatement = R"( + CREATE TABLE User ( + id {0}, + fullname VARCHAR(128) NOT NULL, + email VARCHAR(60) NOT NULL + ); + CREATE TABLE TaskList ( + id {0}, + user_id INT NOT NULL, + CONSTRAINT fk1 FOREIGN KEY (user_id) REFERENCES user(id) + ); + CREATE TABLE TaskListEntry ( + id {0}, + tasklist_id INT NOT NULL, + completed DATETIME NULL, + task VARCHAR(255) NOT NULL, + CONSTRAINT fk1 FOREIGN KEY (tasklist_id) REFERENCES TaskList(id) + ); + )"; + auto stmt = SqlStatement(); + stmt.ExecuteDirect(std::format(createStatement, stmt.Connection().Traits().PrimaryKeyAutoIncrement)); +} + +void PostConnectedHook(SqlConnection& connection) +{ + switch (connection.ServerType()) + { + case SqlServerType::SQLITE: { + auto stmt = SqlStatement { connection }; + // Enable foreign key constraints for SQLite + (void) stmt.ExecuteDirect("PRAGMA foreign_keys = ON"); + break; + } + case SqlServerType::MICROSOFT_SQL: + case SqlServerType::POSTGRESQL: + case SqlServerType::ORACLE: + case SqlServerType::MYSQL: + case SqlServerType::UNKNOWN: + break; + } +} + +void PrintInfo() +{ + auto c = SqlConnection(); + assert(c.IsAlive()); + std::println("Connected to : {}", c.DatabaseName()); + std::println("Server name : {}", c.ServerName()); + std::println("Server version : {}", c.ServerVersion()); + std::println("User name : {}", c.UserName()); + std::println(""); +} + +} // end namespace + +struct Configuration +{ + std::string_view connectionString; + std::string_view database; + std::string_view schema; + std::string_view modelNamespace; + std::string_view outputFileName; + bool createTestTables = false; +}; + +std::expected ParseArguments(int argc, char const* argv[]) +{ + using namespace std::string_view_literals; + auto config = Configuration {}; + + int i = 1; + + for (; i < argc; ++i) + { + if (argv[i] == "--trace-sql"sv) + SqlLogger::SetLogger(SqlLogger::TraceLogger()); + else if (argv[i] == "--connection-string"sv) + { + if (++i >= argc) + return std::unexpected { EXIT_FAILURE }; + config.connectionString = argv[i]; + } + else if (argv[i] == "--database"sv) + { + if (++i >= argc) + return std::unexpected { EXIT_FAILURE }; + config.database = argv[i]; + } + else if (argv[i] == "--schema"sv) + { + if (++i >= argc) + return std::unexpected { EXIT_FAILURE }; + config.schema = argv[i]; + } + else if (argv[i] == "--create-test-tables"sv) + config.createTestTables = true; + else if (argv[i] == "--model-namespace"sv) + { + if (++i >= argc) + return std::unexpected { EXIT_FAILURE }; + config.modelNamespace = argv[i]; + } + else if (argv[i] == "--output"sv) + { + if (++i >= argc) + return std::unexpected { EXIT_FAILURE }; + config.outputFileName = argv[i]; + } + else if (argv[i] == "--help"sv || argv[i] == "-h"sv) + { + std::println("Usage: {} [options] [database] [schema]", argv[0]); + std::println("Options:"); + std::println(" --trace-sql Enable SQL tracing"); + std::println(" --connection-string STR ODBC connection string"); + std::println(" --database STR Database name"); + std::println(" --schema STR Schema name"); + std::println(" --create-test-tables Create test tables"); + std::println(" --output STR Output file name"); + std::println(" --help, -h Display this information"); + std::println(""); + return std::unexpected { EXIT_SUCCESS }; + } + else if (argv[i] == "--"sv) + { + ++i; + break; + } + else + { + std::println("Unknown option: {}", argv[i]); + return std::unexpected { EXIT_FAILURE }; + } + } + + if (i < argc) + argv[i - 1] = argv[0]; + + return { std::move(config) }; +} + +int main(int argc, char const* argv[]) +{ + auto const configOpt = ParseArguments(argc, argv); + if (!configOpt) + return configOpt.error(); + auto const config = configOpt.value(); + + SqlConnection::SetDefaultConnectInfo(SqlConnectionString { std::string(config.connectionString) }); + SqlConnection::SetPostConnectedHook(&PostConnectedHook); + + auto const _ = finally([] { SqlConnection::KillAllIdle(); }); + + if (config.createTestTables) + CreateTestTables(); + + PrintInfo(); + + std::vector tables = SqlSchema::ReadAllTables(config.database, config.schema); + CxxModelPrinter printer; + for (auto const& table: tables) + printer.PrintTable(table); + + if (config.outputFileName.empty() || config.outputFileName == "-") + std::println("{}", printer.str(config.modelNamespace)); + else + { + auto file = std::ofstream(config.outputFileName.data()); + file << printer.str(config.modelNamespace); + } + + return EXIT_SUCCESS; +} diff --git a/vcpkg-configuration.json b/vcpkg-configuration.json new file mode 100644 index 00000000..c540ec73 --- /dev/null +++ b/vcpkg-configuration.json @@ -0,0 +1,14 @@ +{ + "default-registry": { + "kind": "git", + "baseline": "000d1bda1ffa95a73e0b40334fa4103d6f4d3d48", + "repository": "https://github.com/microsoft/vcpkg" + }, + "registries": [ + { + "kind": "artifact", + "location": "https://github.com/microsoft/vcpkg-ce-catalog/archive/refs/heads/main.zip", + "name": "microsoft" + } + ] +} diff --git a/vcpkg.json b/vcpkg.json new file mode 100644 index 00000000..c2f84107 --- /dev/null +++ b/vcpkg.json @@ -0,0 +1,6 @@ +{ + "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", + "dependencies": [ + { "name": "catch2", "version>=": "3.5.2" } + ] +}