diff --git a/src/Lightweight/CMakeLists.txt b/src/Lightweight/CMakeLists.txt index e836d1c6..c5bb90a4 100644 --- a/src/Lightweight/CMakeLists.txt +++ b/src/Lightweight/CMakeLists.txt @@ -42,6 +42,8 @@ set(HEADER_FILES SqlQuery/Core.hpp SqlQuery/Delete.hpp SqlQuery/Insert.hpp + SqlQuery/Migrate.hpp + SqlQuery/MigrationPlan.hpp SqlQuery/Select.hpp SqlQuery/Update.hpp @@ -70,6 +72,8 @@ set(SOURCE_FILES SqlError.cpp SqlLogger.cpp SqlQuery.cpp + SqlQuery/Migrate.cpp + SqlQuery/MigrationPlan.cpp SqlQuery/Select.cpp SqlQueryFormatter.cpp SqlSchema.cpp diff --git a/src/Lightweight/SqlConnection.cpp b/src/Lightweight/SqlConnection.cpp index 48b85095..b7c95ece 100644 --- a/src/Lightweight/SqlConnection.cpp +++ b/src/Lightweight/SqlConnection.cpp @@ -321,3 +321,8 @@ SqlQueryBuilder SqlConnection::QueryAs(std::string_view const& table, std::strin { return SqlQueryBuilder(QueryFormatter(), std::string(table), std::string(tableAlias)); } + +SqlMigrationQueryBuilder SqlConnection::Migration() const +{ + return SqlMigrationQueryBuilder(QueryFormatter()); +} diff --git a/src/Lightweight/SqlConnection.hpp b/src/Lightweight/SqlConnection.hpp index ba1bef28..3ca9f7c8 100644 --- a/src/Lightweight/SqlConnection.hpp +++ b/src/Lightweight/SqlConnection.hpp @@ -29,6 +29,7 @@ #include class SqlQueryBuilder; +class SqlMigrationQueryBuilder; class SqlQueryFormatter; // @brief Represents a connection to a SQL database. @@ -108,12 +109,15 @@ class LIGHTWEIGHT_API SqlConnection final // Retrieves a query formatter suitable for the SQL server being connected. [[nodiscard]] SqlQueryFormatter const& QueryFormatter() const noexcept; - // Creates a new query builder for the given table, compatible with the SQL server being connected. + // Creates a new query builder for the given table, compatible with the current connection. [[nodiscard]] SqlQueryBuilder Query(std::string_view const& table = {}) const; - // Creates a new query builder for the given table with an alias, compatible with the SQL server being connected. + // Creates a new query builder for the given table with an alias, compatible with the current connection. [[nodiscard]] SqlQueryBuilder QueryAs(std::string_view const& table, std::string_view const& tableAlias) const; + // Creates a new migration query builder, compatible the current connection. + [[nodiscard]] SqlMigrationQueryBuilder Migration() const; + // Retrieves the SQL traits for the server. [[nodiscard]] SqlTraits const& Traits() const noexcept { diff --git a/src/Lightweight/SqlQuery.cpp b/src/Lightweight/SqlQuery.cpp index 4b826066..12cca6ac 100644 --- a/src/Lightweight/SqlQuery.cpp +++ b/src/Lightweight/SqlQuery.cpp @@ -34,3 +34,8 @@ SqlDeleteQueryBuilder SqlQueryBuilder::Delete() noexcept { return SqlDeleteQueryBuilder(m_formatter, std::move(m_table), std::move(m_tableAlias)); } + +LIGHTWEIGHT_API SqlMigrationQueryBuilder SqlQueryBuilder::Migration() +{ + return SqlMigrationQueryBuilder(m_formatter); +} diff --git a/src/Lightweight/SqlQuery.hpp b/src/Lightweight/SqlQuery.hpp index bfd6f31d..8f35f67c 100644 --- a/src/Lightweight/SqlQuery.hpp +++ b/src/Lightweight/SqlQuery.hpp @@ -4,6 +4,7 @@ #include "Api.hpp" #include "SqlQuery/Delete.hpp" #include "SqlQuery/Insert.hpp" +#include "SqlQuery/Migrate.hpp" #include "SqlQuery/Select.hpp" #include "SqlQuery/Update.hpp" @@ -42,6 +43,8 @@ class [[nodiscard]] SqlQueryBuilder final // Initiates DELETE query building LIGHTWEIGHT_API SqlDeleteQueryBuilder Delete() noexcept; + LIGHTWEIGHT_API SqlMigrationQueryBuilder Migration(); + private: SqlQueryFormatter const& m_formatter; std::string m_table; diff --git a/src/Lightweight/SqlQuery/Migrate.cpp b/src/Lightweight/SqlQuery/Migrate.cpp new file mode 100644 index 00000000..2aced77c --- /dev/null +++ b/src/Lightweight/SqlQuery/Migrate.cpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "../SqlQueryFormatter.hpp" +#include "Migrate.hpp" + +SqlMigrationPlan SqlMigrationQueryBuilder::GetPlan() +{ + return _migrationPlan; +} + +SqlMigrationQueryBuilder& SqlMigrationQueryBuilder::DropTable(std::string_view tableName) +{ + _migrationPlan.steps.emplace_back(SqlDropTablePlan { + .tableName = tableName, + }); + return *this; +} + +SqlCreateTableQueryBuilder SqlMigrationQueryBuilder::CreateTable(std::string_view tableName) +{ + _migrationPlan.steps.emplace_back(SqlCreateTablePlan { + .tableName = tableName, + .columns = {}, + }); + return SqlCreateTableQueryBuilder { std::get(_migrationPlan.steps.back()) }; +} + +SqlAlterTableQueryBuilder SqlMigrationQueryBuilder::AlterTable(std::string_view tableName) +{ + _migrationPlan.steps.emplace_back(SqlAlterTablePlan { + .tableName = tableName, + .commands = {}, + }); + return SqlAlterTableQueryBuilder { std::get(_migrationPlan.steps.back()) }; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::RenameTo(std::string_view newTableName) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::RenameTable { + .newTableName = newTableName, + }); + return *this; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::AddColumn(std::string_view columnName, + SqlColumnTypeDefinition columnType) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::AddColumn { + .columnName = columnName, + .columnType = columnType, + }); + return *this; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::RenameColumn(std::string_view oldColumnName, + std::string_view newColumnName) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::RenameColumn { + .oldColumnName = oldColumnName, + .newColumnName = newColumnName, + }); + return *this; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::DropColumn(std::string_view columnName) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::DropColumn { + .columnName = columnName, + }); + return *this; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::AddIndex(std::string_view columnName) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::AddIndex { + .columnName = columnName, + .unique = false, + }); + return *this; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::AddUniqueIndex(std::string_view columnName) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::AddIndex { + .columnName = columnName, + .unique = true, + }); + return *this; +} + +SqlAlterTableQueryBuilder& SqlAlterTableQueryBuilder::DropIndex(std::string_view columnName) +{ + _plan.commands.emplace_back(SqlAlterTableCommands::DropIndex { + .columnName = columnName, + }); + return *this; +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::Column(SqlColumnDeclaration column) +{ + _plan.columns.emplace_back(std::move(column)); + return *this; +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::Column(std::string columnName, + SqlColumnTypeDefinition columnType) +{ + return Column(SqlColumnDeclaration { + .name = std::move(columnName), + .type = columnType, + }); +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::RequiredColumn(std::string columnName, + SqlColumnTypeDefinition columnType) +{ + return Column(SqlColumnDeclaration { + .name = std::move(columnName), + .type = columnType, + .required = true, + }); +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::PrimaryKey(std::string columnName, + SqlColumnTypeDefinition columnType) +{ + return Column(SqlColumnDeclaration { + .name = std::move(columnName), + .type = columnType, + .primaryKey = SqlPrimaryKeyType::MANUAL, + .required = true, + .unique = true, + .index = true, + }); +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::PrimaryKeyWithAutoIncrement(std::string columnName, + SqlColumnTypeDefinition columnType) +{ + return Column(SqlColumnDeclaration { + .name = std::move(columnName), + .type = columnType, + .primaryKey = SqlPrimaryKeyType::AUTO_INCREMENT, + .required = true, + .unique = true, + .index = true, + }); +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::Unique() +{ + _plan.columns.back().unique = true; + return *this; +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::Index() +{ + _plan.columns.back().index = true; + return *this; +} + +SqlCreateTableQueryBuilder& SqlCreateTableQueryBuilder::UniqueIndex() +{ + _plan.columns.back().index = true; + _plan.columns.back().unique = true; + return *this; +} diff --git a/src/Lightweight/SqlQuery/Migrate.hpp b/src/Lightweight/SqlQuery/Migrate.hpp new file mode 100644 index 00000000..c80b88e9 --- /dev/null +++ b/src/Lightweight/SqlQuery/Migrate.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "Core.hpp" +#include "MigrationPlan.hpp" + +#include + +class [[nodiscard]] SqlCreateTableQueryBuilder final +{ + public: + explicit SqlCreateTableQueryBuilder(SqlCreateTablePlan& plan): + _plan { plan } + { + } + + // Adds a new column to the table. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& Column(SqlColumnDeclaration column); + + // Creates a new nullable column. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& Column(std::string columnName, SqlColumnTypeDefinition columnType); + + // Creates a new column that is non-nullable. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& RequiredColumn(std::string columnName, + SqlColumnTypeDefinition columnType); + + // Creates a new primary key column. + // Primary keys are always required, unique, have an index, and are non-nullable. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& PrimaryKey(std::string columnName, SqlColumnTypeDefinition columnType); + + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& PrimaryKeyWithAutoIncrement( + std::string columnName, SqlColumnTypeDefinition columnType = SqlColumnTypeDefinitions::Bigint {}); + + // Enables the UNIQUE constraint on the last declared column. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& Unique(); + + // Enables the UNIQUE constraint on the last declared column. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& Index(); + + // Enables the UNIQUE constraint on the last declared column and makes it an index. + LIGHTWEIGHT_API SqlCreateTableQueryBuilder& UniqueIndex(); + + private: + SqlCreateTablePlan& _plan; +}; + +class [[nodiscard]] SqlAlterTableQueryBuilder final +{ + public: + explicit SqlAlterTableQueryBuilder(SqlAlterTablePlan& plan): + _plan { plan } + { + } + + // Renames the table. + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& RenameTo(std::string_view newTableName); + + // Adds a new column to the table that is non-nullable. + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& AddColumn(std::string_view columnName, + SqlColumnTypeDefinition columnType); + + // Adds a new column to the table that is nullable. + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& AddColumnAsNullable(std::string_view columnName, + SqlColumnTypeDefinition columnType); + + // Alters the column to have a new non-nullable type. + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& AlterColumn(std::string_view columnName, + SqlColumnTypeDefinition columnType); + + // Alters the column to have a new nullable type. + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& AlterColumnAsNullable(std::string_view columnName, + SqlColumnTypeDefinition columnType); + + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& RenameColumn(std::string_view oldColumnName, + std::string_view newColumnName); + + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& DropColumn(std::string_view columnName); + + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& AddIndex(std::string_view columnName); + + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& AddUniqueIndex(std::string_view columnName); + + LIGHTWEIGHT_API SqlAlterTableQueryBuilder& DropIndex(std::string_view columnName); + + private: + SqlAlterTablePlan& _plan; +}; + +class [[nodiscard]] SqlMigrationQueryBuilder final +{ + public: + explicit SqlMigrationQueryBuilder(SqlQueryFormatter const& formatter): + _formatter { formatter }, + _migrationPlan { .formatter = formatter } + { + } + + LIGHTWEIGHT_API SqlMigrationQueryBuilder& CreateDatabase(std::string_view databaseName); + LIGHTWEIGHT_API SqlMigrationQueryBuilder& DropDatabase(std::string_view databaseName); + + LIGHTWEIGHT_API SqlCreateTableQueryBuilder CreateTable(std::string_view tableName); + LIGHTWEIGHT_API SqlAlterTableQueryBuilder AlterTable(std::string_view tableName); + LIGHTWEIGHT_API SqlMigrationQueryBuilder& DropTable(std::string_view tableName); + + LIGHTWEIGHT_API SqlMigrationQueryBuilder& RawSql(std::string_view sql); + LIGHTWEIGHT_API SqlMigrationQueryBuilder& Native(std::function callback); + + LIGHTWEIGHT_API SqlMigrationQueryBuilder& BeginTransaction(); + LIGHTWEIGHT_API SqlMigrationQueryBuilder& CommitTransaction(); + + LIGHTWEIGHT_API SqlMigrationPlan GetPlan(); + + private: + SqlQueryFormatter const& _formatter; + SqlMigrationPlan _migrationPlan; +}; diff --git a/src/Lightweight/SqlQuery/MigrationPlan.cpp b/src/Lightweight/SqlQuery/MigrationPlan.cpp new file mode 100644 index 00000000..a87c4154 --- /dev/null +++ b/src/Lightweight/SqlQuery/MigrationPlan.cpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "../SqlQueryFormatter.hpp" +#include "MigrationPlan.hpp" + +std::string SqlMigrationPlan::ToSql() const +{ + std::string result; + for (auto const& step: steps) + result += ToSql(formatter, step); + return result; +} + +std::string SqlMigrationPlan::ToSql(SqlQueryFormatter const& formatter, SqlMigrationPlanElement const& element) +{ + using namespace std::string_literals; + return std::visit( + [&](auto const& step) { + if constexpr (std::is_same_v, SqlCreateTablePlan>) + { + return formatter.CreateTable(step.tableName, step.columns); + } + else if constexpr (std::is_same_v, SqlAlterTablePlan>) + { + return formatter.AlterTable(step.tableName, step.commands); + } + else if constexpr (std::is_same_v, SqlDropTablePlan>) + { + return formatter.DropTable(step.tableName); + } + else + { + static_assert(false, "non-exhaustive visitor"); + } + }, + element); +} diff --git a/src/Lightweight/SqlQuery/MigrationPlan.hpp b/src/Lightweight/SqlQuery/MigrationPlan.hpp new file mode 100644 index 00000000..9f8c7db5 --- /dev/null +++ b/src/Lightweight/SqlQuery/MigrationPlan.hpp @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "../Api.hpp" + +#include + +#include +#include +#include +#include + +class SqlQueryFormatter; + +// clang-format off +namespace SqlColumnTypeDefinitions +{ + +struct Bool {}; +struct Char { size_t size = 1; }; +struct Varchar { size_t size {}; }; +struct Text { size_t size {}; }; +struct Smallint {}; +struct Integer {}; +struct Bigint {}; +struct Real {}; +struct Decimal { size_t precision {}; size_t scale {}; }; +struct DateTime {}; +struct Timestamp {}; +struct Date {}; +struct Time {}; +struct Guid {}; + +} // namespace SqlColumnTypeDefinitions + +using SqlColumnTypeDefinition = std::variant< + SqlColumnTypeDefinitions::Bigint, + SqlColumnTypeDefinitions::Bool, + SqlColumnTypeDefinitions::Char, + SqlColumnTypeDefinitions::Date, + SqlColumnTypeDefinitions::DateTime, + SqlColumnTypeDefinitions::Decimal, + SqlColumnTypeDefinitions::Guid, + SqlColumnTypeDefinitions::Integer, + SqlColumnTypeDefinitions::Real, + SqlColumnTypeDefinitions::Smallint, + SqlColumnTypeDefinitions::Text, + SqlColumnTypeDefinitions::Time, + SqlColumnTypeDefinitions::Timestamp, + SqlColumnTypeDefinitions::Varchar +>; +// clang-format on + +enum class SqlPrimaryKeyType : uint8_t +{ + NONE, + MANUAL, + AUTO_INCREMENT, + GUID, +}; + +struct SqlColumnDeclaration +{ + std::string name; + SqlColumnTypeDefinition type; + SqlPrimaryKeyType primaryKey { SqlPrimaryKeyType::NONE }; + bool required { false }; + bool unique { false }; + bool index { false }; +}; + +struct SqlCreateTablePlan +{ + std::string_view tableName; + std::vector columns; +}; + +namespace SqlAlterTableCommands +{ + +struct RenameTable +{ + std::string_view newTableName; +}; + +struct AddColumn +{ + std::string_view columnName; + SqlColumnTypeDefinition columnType; +}; + +struct AddIndex +{ + std::string_view columnName; + bool unique = false; +}; + +struct RenameColumn +{ + std::string_view oldColumnName; + std::string_view newColumnName; +}; + +struct DropColumn +{ + std::string_view columnName; +}; + +struct DropIndex +{ + std::string_view columnName; +}; + +}; // namespace SqlAlterTableCommands + +using SqlAlterTableCommand = std::variant; + +struct SqlAlterTablePlan +{ + std::string_view tableName; + std::vector commands; +}; + +struct SqlDropTablePlan +{ + std::string_view tableName; +}; + +// clang-format off +using SqlMigrationPlanElement = std::variant< + SqlCreateTablePlan, + SqlAlterTablePlan, + SqlDropTablePlan +>; +// clang-format on + +struct [[nodiscard]] SqlMigrationPlan +{ + SqlQueryFormatter const& formatter; + std::vector steps {}; + + [[nodiscard]] LIGHTWEIGHT_API std::string ToSql() const; + + [[nodiscard]] LIGHTWEIGHT_API static std::string ToSql(SqlQueryFormatter const& formatter, + SqlMigrationPlanElement const& element); +}; diff --git a/src/Lightweight/SqlQueryFormatter.cpp b/src/Lightweight/SqlQueryFormatter.cpp index c54dce95..fa3fdbb9 100644 --- a/src/Lightweight/SqlQueryFormatter.cpp +++ b/src/Lightweight/SqlQueryFormatter.cpp @@ -2,7 +2,10 @@ #include "SqlQueryFormatter.hpp" +#include + #include +#include #include using namespace std::string_view_literals; @@ -154,6 +157,181 @@ class BasicSqlQueryFormatter: public SqlQueryFormatter return std::format( R"(DELETE FROM "{}" AS "{}"{}{})", fromTable, fromTableAlias, tableJoins, whereCondition); } + + [[nodiscard]] virtual std::string BuildColumnDefinition(SqlColumnDeclaration const& column) const + { + std::stringstream sqlQueryString; + sqlQueryString << '"' << column.name << "\" "; + + if (column.primaryKey != SqlPrimaryKeyType::AUTO_INCREMENT) + sqlQueryString << ColumnType(column.type); + else + sqlQueryString << ColumnType(SqlColumnTypeDefinitions::Integer {}); + + if (column.required) + sqlQueryString << " NOT NULL"; + + if (column.primaryKey != SqlPrimaryKeyType::NONE) + sqlQueryString << " PRIMARY KEY"; + else if (column.unique && !column.index) + sqlQueryString << " UNIQUE"; + + if (column.primaryKey == SqlPrimaryKeyType::AUTO_INCREMENT) + sqlQueryString << " AUTOINCREMENT"; + + return sqlQueryString.str(); + } + + [[nodiscard]] std::string CreateTable(std::string_view tableName, + std::vector const& columns) const override + { + std::stringstream sqlQueryString; + + sqlQueryString << "CREATE TABLE \"" << tableName << "\" ("; + + size_t currentColumn = 0; + for (SqlColumnDeclaration const& column: columns) + { + if (currentColumn > 0) + sqlQueryString << ","; + ++currentColumn; + sqlQueryString << "\n "; + sqlQueryString << BuildColumnDefinition(column); + } + sqlQueryString << "\n);"; + + for (SqlColumnDeclaration const& column: columns) + { + if (column.index && column.primaryKey == SqlPrimaryKeyType::NONE) + { + // primary keys are always indexed + if (column.unique) + sqlQueryString << std::format("\nCREATE UNIQUE INDEX \"{}_{}_index\" ON \"{}\"(\"{}\");", + tableName, + column.name, + tableName, + column.name); + else + sqlQueryString << std::format("\nCREATE INDEX \"{}_{}_index\" ON \"{}\"(\"{}\");", + tableName, + column.name, + tableName, + column.name); + } + } + + return sqlQueryString.str(); + } + + [[nodiscard]] std::string AlterTable(std::string_view tableName, + std::vector const& commands) const override + { + std::stringstream sqlQueryString; + + int currentCommand = 0; + for (SqlAlterTableCommand const& command: commands) + { + if (currentCommand > 0) + sqlQueryString << '\n'; + ++currentCommand; + + // sqlQueryString << "ALTER TABLE \"" << tableName << "\" "; + + sqlQueryString << std::visit( + [this, tableName](auto const& actualCommand) -> std::string { + using Type = std::decay_t; + if constexpr (std::same_as) + { + return std::format( + R"(ALTER TABLE "{}" RENAME TO "{}";)", tableName, actualCommand.newTableName); + } + else if constexpr (std::same_as) + { + return std::format(R"(ALTER TABLE "{}" ADD COLUMN "{}" {};)", + tableName, + actualCommand.columnName, + ColumnType(actualCommand.columnType)); + } + else if constexpr (std::same_as) + { + return std::format(R"(ALTER TABLE "{}" RENAME COLUMN "{}" TO "{}";)", + tableName, + actualCommand.oldColumnName, + actualCommand.newColumnName); + } + else if constexpr (std::same_as) + { + return std::format( + R"(ALTER TABLE "{}" DROP COLUMN "{}";)", tableName, actualCommand.columnName); + } + else if constexpr (std::same_as) + { + auto const uniqueStr = actualCommand.unique ? "UNIQUE "sv : ""sv; + return std::format(R"(CREATE {2}INDEX "{0}_{1}_index" ON "{0}"("{1}");)", + tableName, + actualCommand.columnName, + uniqueStr); + } + else if constexpr (std::same_as) + { + return std::format(R"(DROP INDEX "{0}_{1}_index";)", tableName, actualCommand.columnName); + } + else + { + throw std::runtime_error( + std::format("Unknown alter table command: {}", Reflection::TypeName)); + } + }, + command); + } + + return sqlQueryString.str(); + } + + [[nodiscard]] std::string ColumnType(SqlColumnTypeDefinition const& type) const override + { + using namespace SqlColumnTypeDefinitions; + return std::visit( + [](auto const& actualType) -> std::string { + using Type = std::decay_t; + if constexpr (std::same_as) + return "BIGINT"; + else if constexpr (std::same_as) + return "BOOLEAN"; + else if constexpr (std::same_as) + return std::format("CHAR({})", actualType.size); + else if constexpr (std::same_as) + return "DATE"; + else if constexpr (std::same_as) + return "DATETIME"; + else if constexpr (std::same_as) + return std::format("DECIMAL({}, {})", actualType.precision, actualType.scale); + else if constexpr (std::same_as) + return "GUID"; + else if constexpr (std::same_as) + return "INTEGER"; + else if constexpr (std::same_as) + return "REAL"; + else if constexpr (std::same_as) + return "SMALLINT"; + else if constexpr (std::same_as) + return "TEXT"; + else if constexpr (std::same_as) + return "TIME"; + else if constexpr (std::same_as) + return "TIMESTAMP"; + else if constexpr (std::same_as) + return std::format("VARCHAR({})", actualType.size); + else + throw std::runtime_error(std::format("Unknown column type: {}", Reflection::TypeName)); + }, + type); + } + + [[nodiscard]] std::string DropTable(std::string_view const& tableName) const override + { + return std::format(R"(DROP TABLE "{}";)", tableName); + } }; class SqlServerQueryFormatter final: public BasicSqlQueryFormatter @@ -215,6 +393,85 @@ class SqlServerQueryFormatter final: public BasicSqlQueryFormatter sqlQueryString << " OFFSET " << offset << " ROWS FETCH NEXT " << limit << " ROWS ONLY"; return sqlQueryString.str(); } + + [[nodiscard]] std::string ColumnType(SqlColumnTypeDefinition const& type) const override + { + using namespace SqlColumnTypeDefinitions; + return std::visit( + [this, type](auto const& actualType) -> std::string { + using Type = std::decay_t; + if constexpr (std::same_as) + return "BIT"; + else if constexpr (std::same_as) + return "UNIQUEIDENTIFIER"; + else if constexpr (std::same_as) + return "VARCHAR(MAX)"; + else + return BasicSqlQueryFormatter::ColumnType(type); + }, + type); + } + + [[nodiscard]] std::string BuildColumnDefinition(SqlColumnDeclaration const& column) const override + { + std::stringstream sqlQueryString; + sqlQueryString << '"' << column.name << "\" " << ColumnType(column.type); + + if (column.required) + sqlQueryString << " NOT NULL"; + + if (column.primaryKey == SqlPrimaryKeyType::AUTO_INCREMENT) + sqlQueryString << " IDENTITY(1,1)"; + + if (column.primaryKey != SqlPrimaryKeyType::NONE) + sqlQueryString << " PRIMARY KEY"; + + if (column.unique && !column.index) + sqlQueryString << " UNIQUE"; + + return sqlQueryString.str(); + } +}; + +class PostgreSqlFormatter final: public BasicSqlQueryFormatter +{ + public: + [[nodiscard]] std::string BuildColumnDefinition(SqlColumnDeclaration const& column) const override + { + std::stringstream sqlQueryString; + + sqlQueryString << '"' << column.name << "\" "; + + if (column.primaryKey == SqlPrimaryKeyType::AUTO_INCREMENT) + sqlQueryString << "SERIAL"; + else + sqlQueryString << ColumnType(column.type); + + if (column.required) + sqlQueryString << " NOT NULL"; + + if (column.primaryKey != SqlPrimaryKeyType::NONE) + sqlQueryString << " PRIMARY KEY"; + + if (column.unique && !column.index) + sqlQueryString << " UNIQUE"; + + return sqlQueryString.str(); + } + + [[nodiscard]] std::string ColumnType(SqlColumnTypeDefinition const& type) const override + { + using namespace SqlColumnTypeDefinitions; + return std::visit( + [this, type](auto const& actualType) -> std::string { + using Type = std::decay_t; + if constexpr (std::same_as) + return "UUID"; + else + return BasicSqlQueryFormatter::ColumnType(type); + }, + type); + } }; } // namespace @@ -233,7 +490,7 @@ SqlQueryFormatter const& SqlQueryFormatter::SqlServer() SqlQueryFormatter const& SqlQueryFormatter::PostgrSQL() { - static const BasicSqlQueryFormatter formatter {}; + static const PostgreSqlFormatter formatter {}; return formatter; } diff --git a/src/Lightweight/SqlQueryFormatter.hpp b/src/Lightweight/SqlQueryFormatter.hpp index bc92adfb..bb6dd170 100644 --- a/src/Lightweight/SqlQueryFormatter.hpp +++ b/src/Lightweight/SqlQueryFormatter.hpp @@ -4,6 +4,7 @@ #include "Api.hpp" #include "SqlConnection.hpp" +#include "SqlQuery/MigrationPlan.hpp" #include #include @@ -76,6 +77,11 @@ class [[nodiscard]] LIGHTWEIGHT_API SqlQueryFormatter std::string const& tableJoins, std::string const& whereCondition) const = 0; + [[nodiscard]] virtual std::string ColumnType(SqlColumnTypeDefinition const& type) const = 0; + [[nodiscard]] virtual std::string CreateTable(std::string_view tableName, std::vector const& columns) const = 0; + [[nodiscard]] virtual std::string AlterTable(std::string_view tableName, std::vector const& commands) const = 0; + [[nodiscard]] virtual std::string DropTable(std::string_view const& tableName) const = 0; + static SqlQueryFormatter const& Sqlite(); static SqlQueryFormatter const& SqlServer(); static SqlQueryFormatter const& PostgrSQL(); diff --git a/src/Lightweight/SqlStatement.hpp b/src/Lightweight/SqlStatement.hpp index acd9bbe1..2e94eabe 100644 --- a/src/Lightweight/SqlStatement.hpp +++ b/src/Lightweight/SqlStatement.hpp @@ -165,6 +165,10 @@ class SqlStatement final: public SqlDataBinderCallback void ExecuteDirect(SqlQueryObject auto const& query, std::source_location location = std::source_location::current()); + // Executes an SQL migration query, as created b the callback. + template + void MigrateDirect(Callable const& callable, std::source_location location = std::source_location::current()); + // Executes the given query, assuming that only one result row and column is affected, that one will be // returned. template @@ -647,6 +651,14 @@ inline LIGHTWEIGHT_FORCE_INLINE void SqlStatement::ExecuteDirect(SqlQueryObject return ExecuteDirect(query.ToSql(), location); } +template +void SqlStatement::MigrateDirect(Callable const& callable, std::source_location location) +{ + auto migration = SqlMigrationQueryBuilder { Connection().QueryFormatter() }; + callable(migration); + ExecuteDirect(migration.GetPlan(), location); +} + template requires(!std::same_as) inline std::optional SqlStatement::ExecuteDirectScalar(const std::string_view& query, std::source_location location) diff --git a/src/tests/CoreTests.cpp b/src/tests/CoreTests.cpp index 9f7f184f..1af639e8 100644 --- a/src/tests/CoreTests.cpp +++ b/src/tests/CoreTests.cpp @@ -127,21 +127,21 @@ TEST_CASE_METHOD(SqlTestFixture, "execute bound parameters and select back: VARC CreateEmployeesTable(stmt); REQUIRE(!stmt.IsPrepared()); - stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Prepare(R"(INSERT INTO "Employees" ("FirstName", "LastName", "Salary") VALUES (?, ?, ?))"); REQUIRE(stmt.IsPrepared()); stmt.Execute("Alice", "Smith", 50'000); stmt.Execute("Bob", "Johnson", 60'000); stmt.Execute("Charlie", "Brown", 70'000); - stmt.ExecuteDirect("SELECT COUNT(*) FROM Employees"); + stmt.ExecuteDirect(R"(SELECT COUNT(*) FROM "Employees")"); REQUIRE(!stmt.IsPrepared()); REQUIRE(stmt.NumColumnsAffected() == 1); (void) stmt.FetchRow(); REQUIRE(stmt.GetColumn(1) == 3); REQUIRE(!stmt.FetchRow()); - stmt.Prepare("SELECT FirstName, LastName, Salary FROM Employees WHERE Salary >= ?"); + stmt.Prepare(R"(SELECT "FirstName", "LastName", "Salary" FROM "Employees" WHERE "Salary" >= ?)"); REQUIRE(stmt.NumColumnsAffected() == 3); stmt.Execute(55'000); @@ -262,7 +262,12 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlStatement.ExecuteBatchNative") auto stmt = SqlStatement {}; UNSUPPORTED_DATABASE(stmt, SqlServerType::ORACLE); - stmt.ExecuteDirect("CREATE TABLE Test (A VARCHAR(8), B REAL, C INTEGER)"); + stmt.MigrateDirect([](SqlMigrationQueryBuilder& migration) { + migration.CreateTable("Test") + .Column("A", SqlColumnTypeDefinitions::Varchar { 8 }) + .Column("B", SqlColumnTypeDefinitions::Real {}) + .Column("C", SqlColumnTypeDefinitions::Integer {}); + }); stmt.Prepare("INSERT INTO Test (A, B, C) VALUES (?, ?, ?)"); @@ -364,7 +369,11 @@ TEST_CASE_METHOD(SqlTestFixture, "SELECT * FROM Table") TEST_CASE_METHOD(SqlTestFixture, "GetNullableColumn") { auto stmt = SqlStatement {}; - stmt.ExecuteDirect("CREATE TABLE Test (Remarks1 VARCHAR(50) NULL, Remarks2 VARCHAR(50) NULL)"); + stmt.MigrateDirect([](SqlMigrationQueryBuilder& migration) { + migration.CreateTable("Test") + .Column("Remarks1", SqlColumnTypeDefinitions::Varchar { 50 }) + .Column("Remarks2", SqlColumnTypeDefinitions::Varchar { 50 }); + }); stmt.Prepare("INSERT INTO Test (Remarks1, Remarks2) VALUES (?, ?)"); stmt.Execute("Blurb", SqlNullValue); diff --git a/src/tests/QueryBuilderTests.cpp b/src/tests/QueryBuilderTests.cpp index dd3add4b..21750fe9 100644 --- a/src/tests/QueryBuilderTests.cpp +++ b/src/tests/QueryBuilderTests.cpp @@ -10,17 +10,21 @@ #include #include +#include #include #include struct QueryExpectations { std::string_view sqlite; + std::string_view postgres; std::string_view sqlServer; + std::string_view oracle; static QueryExpectations All(std::string_view query) { - return { query, query }; + // NOLINTNEXTLINE(modernize-use-designated-initializers) + return { query, query, query, query }; } }; @@ -32,6 +36,25 @@ auto EraseLinefeeds(std::string str) noexcept -> std::string return str; } +[[nodiscard]] std::string NormalizeText(std::string_view const& text) +{ + auto result = std::string(text); + + // Remove any newlines and reduce all whitespace to a single space + result.erase( + std::unique(result.begin(), result.end(), [](char a, char b) { return std::isspace(a) && std::isspace(b); }), + result.end()); + + // trim lading and trailing whitespace + while (!result.empty() && std::isspace(result.front())) + result.erase(result.begin()); + + while (!result.empty() && std::isspace(result.back())) + result.pop_back(); + + return result; +} + template requires(std::is_invocable_v) void checkSqlQueryBuilder(TheSqlQuery const& sqlQueryBuilder, @@ -41,19 +64,21 @@ void checkSqlQueryBuilder(TheSqlQuery const& sqlQueryBuilder, { INFO(std::format("Test source location: {}:{}", location.file_name(), location.line())); - auto const& sqliteFormatter = SqlQueryFormatter::Sqlite(); - auto sqliteQueryBuilder = SqlQueryBuilder(sqliteFormatter); - auto const actualSqlite = EraseLinefeeds(sqlQueryBuilder(sqliteQueryBuilder).ToSql()); - CHECK(actualSqlite == expectations.sqlite); - if (postCheck) - postCheck(); - - auto const& sqlServerFormatter = SqlQueryFormatter::SqlServer(); - auto sqlServerQueryBuilder = SqlQueryBuilder(sqlServerFormatter); - auto const actualSqlServer = EraseLinefeeds(sqlQueryBuilder(sqlServerQueryBuilder).ToSql()); - CHECK(actualSqlServer == expectations.sqlServer); - if (postCheck) - postCheck(); + auto const checkOne = [&](SqlQueryFormatter const& formatter, std::string_view name, std::string_view query) { + INFO("Testing " << name); + auto sqliteQueryBuilder = SqlQueryBuilder(formatter); + auto const sqlQuery = sqlQueryBuilder(sqliteQueryBuilder); + auto const actual = NormalizeText(sqlQuery.ToSql()); + auto const expected = NormalizeText(query); + REQUIRE(actual == expected); + if (postCheck) + postCheck(); + }; + + checkOne(SqlQueryFormatter::Sqlite(), "SQLite", expectations.sqlite); + checkOne(SqlQueryFormatter::PostgrSQL(), "Postgres", expectations.postgres); + checkOne(SqlQueryFormatter::SqlServer(), "SQL Server", expectations.sqlServer); + // TODO: checkOne(SqlQueryFormatter::OracleSQL(), "Oracle", expectations.oracle); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.Count", "[SqlQueryBuilder]") @@ -68,7 +93,10 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.All", "[SqlQueryBuilder [](SqlQueryBuilder& q) { return q.FromTable("That").Select().Fields("a", "b").Field("c").GroupBy("a").OrderBy("b").All(); }, - QueryExpectations::All(R"(SELECT "a", "b", "c" FROM "That" GROUP BY "a" ORDER BY "b" ASC)")); + QueryExpectations::All(R"( + SELECT "a", "b", "c" FROM "That" + GROUP BY "a" + ORDER BY "b" ASC)")); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.Distinct.All", "[SqlQueryBuilder]") @@ -77,7 +105,10 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.Distinct.All", "[SqlQue [](SqlQueryBuilder& q) { return q.FromTable("That").Select().Distinct().Fields("a", "b").Field("c").GroupBy("a").OrderBy("b").All(); }, - QueryExpectations::All(R"(SELECT DISTINCT "a", "b", "c" FROM "That" GROUP BY "a" ORDER BY "b" ASC)")); + QueryExpectations::All(R"( + SELECT DISTINCT "a", "b", "c" FROM "That" + GROUP BY "a" + ORDER BY "b" ASC)")); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.First", "[SqlQueryBuilder]") @@ -85,8 +116,14 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.First", "[SqlQueryBuild checkSqlQueryBuilder( [](SqlQueryBuilder& q) { return q.FromTable("That").Select().Field("field1").OrderBy("id").First(); }, QueryExpectations { - .sqlite = R"(SELECT "field1" FROM "That" ORDER BY "id" ASC LIMIT 1)", - .sqlServer = R"(SELECT TOP 1 "field1" FROM "That" ORDER BY "id" ASC)", + .sqlite = R"(SELECT "field1" FROM "That" + ORDER BY "id" ASC LIMIT 1)", + .postgres = R"(SELECT "field1" FROM "That" + ORDER BY "id" ASC LIMIT 1)", + .sqlServer = R"(SELECT TOP 1 "field1" FROM "That" + ORDER BY "id" ASC)", + .oracle = R"(SELECT "field1" FROM "That" + ORDER BY "id" ASC FETCH FIRST 1 ROWS ONLY)", }); } @@ -97,8 +134,14 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Select.Range", "[SqlQueryBuild return q.FromTable("That").Select().Fields("foo", "bar").OrderBy("id").Range(200, 50); }, QueryExpectations { - .sqlite = R"(SELECT "foo", "bar" FROM "That" ORDER BY "id" ASC LIMIT 50 OFFSET 200)", - .sqlServer = R"(SELECT "foo", "bar" FROM "That" ORDER BY "id" ASC OFFSET 200 ROWS FETCH NEXT 50 ROWS ONLY)", + .sqlite = R"(SELECT "foo", "bar" FROM "That" + ORDER BY "id" ASC LIMIT 50 OFFSET 200)", + .postgres = R"(SELECT "foo", "bar" FROM "That" + ORDER BY "id" ASC LIMIT 50 OFFSET 200)", + .sqlServer = R"(SELECT "foo", "bar" FROM "That" + ORDER BY "id" ASC OFFSET 200 ROWS FETCH NEXT 50 ROWS ONLY)", + .oracle = R"(SELECT "foo", "bar" FROM "That" + ORDER BY "id" ASC OFFSET 200 ROWS FETCH NEXT 50 ROWS ONLY)", }); } @@ -113,7 +156,9 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Fields", "[SqlQueryBuilder]") checkSqlQueryBuilder([](SqlQueryBuilder& q) { return q.FromTable("Users").Select().Fields().First(); }, QueryExpectations { .sqlite = R"(SELECT "name", "address" FROM "Users" LIMIT 1)", + .postgres = R"(SELECT "name", "address" FROM "Users" LIMIT 1)", .sqlServer = R"(SELECT TOP 1 "name", "address" FROM "Users")", + .oracle = R"(SELECT "name", "address" FROM "Users" FETCH FIRST 1 ROWS ONLY)", }); } @@ -128,7 +173,9 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.FieldsForFieldMembers", "[SqlQ checkSqlQueryBuilder([](SqlQueryBuilder& q) { return q.FromTable("Users").Select().Fields().First(); }, QueryExpectations { .sqlite = R"(SELECT "name", "address" FROM "Users" LIMIT 1)", + .postgres = R"(SELECT "name", "address" FROM "Users" LIMIT 1)", .sqlServer = R"(SELECT TOP 1 "name", "address" FROM "Users")", + .oracle = R"(SELECT "name", "address" FROM "Users" FETCH FIRST 1 ROWS ONLY)", }); } @@ -146,7 +193,9 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.FieldsWithBelongsTo", "[SqlQue }, QueryExpectations { .sqlite = R"(SELECT "email", "user" FROM "QueryBuilderTestEmail" LIMIT 1)", + .postgres = R"(SELECT "email", "user" FROM "QueryBuilderTestEmail" LIMIT 1)", .sqlServer = R"(SELECT TOP 1 "email", "user" FROM "QueryBuilderTestEmail")", + .oracle = R"(SELECT "email", "user" FROM "QueryBuilderTestEmail" FETCH FIRST 1 ROWS ONLY)", }); } @@ -165,7 +214,8 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Where.Junctors", "[SqlQueryBui .Count(); // clang-format on }, - QueryExpectations::All(R"SQL(SELECT COUNT(*) FROM "Table" WHERE a AND b OR c AND d AND NOT e)SQL")); + QueryExpectations::All(R"SQL(SELECT COUNT(*) FROM "Table" + WHERE a AND b OR c AND d AND NOT e)SQL")); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.WhereIn", "[SqlQueryBuilder]") @@ -173,16 +223,19 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.WhereIn", "[SqlQueryBuilder]") // Check functionality of container overloads for IN checkSqlQueryBuilder( [](SqlQueryBuilder& q) { return q.FromTable("That").Delete().WhereIn("foo", std::vector { 1, 2, 3 }); }, - QueryExpectations::All(R"(DELETE FROM "That" WHERE "foo" IN (1, 2, 3))")); + QueryExpectations::All(R"(DELETE FROM "That" + WHERE "foo" IN (1, 2, 3))")); // Check functionality of an lvalue input range auto const values = std::set { 1, 2, 3 }; checkSqlQueryBuilder([&](SqlQueryBuilder& q) { return q.FromTable("That").Delete().WhereIn("foo", values); }, - QueryExpectations::All(R"(DELETE FROM "That" WHERE "foo" IN (1, 2, 3))")); + QueryExpectations::All(R"(DELETE FROM "That" + WHERE "foo" IN (1, 2, 3))")); // Check functionality of the initializer_list overload for IN checkSqlQueryBuilder([](SqlQueryBuilder& q) { return q.FromTable("That").Delete().WhereIn("foo", { 1, 2, 3 }); }, - QueryExpectations::All(R"(DELETE FROM "That" WHERE "foo" IN (1, 2, 3))")); + QueryExpectations::All(R"(DELETE FROM "That" + WHERE "foo" IN (1, 2, 3))")); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Join", "[SqlQueryBuilder]") @@ -192,14 +245,16 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Join", "[SqlQueryBuilder]") return q.FromTable("That").Select().Fields("foo", "bar").InnerJoin("Other", "id", "that_id").All(); }, QueryExpectations::All( - R"(SELECT "foo", "bar" FROM "That" INNER JOIN "Other" ON "Other"."id" = "That"."that_id")")); + R"(SELECT "foo", "bar" FROM "That" + INNER JOIN "Other" ON "Other"."id" = "That"."that_id")")); checkSqlQueryBuilder( [](SqlQueryBuilder& q) { return q.FromTable("That").Select().Fields("foo", "bar").LeftOuterJoin("Other", "id", "that_id").All(); }, QueryExpectations::All( - R"(SELECT "foo", "bar" FROM "That" LEFT OUTER JOIN "Other" ON "Other"."id" = "That"."that_id")")); + R"(SELECT "foo", "bar" FROM "That" + LEFT OUTER JOIN "Other" ON "Other"."id" = "That"."that_id")")); checkSqlQueryBuilder( [](SqlQueryBuilder& q) { @@ -214,8 +269,8 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Join", "[SqlQueryBuilder]") }, QueryExpectations::All("SELECT \"Table_A\".\"foo\", \"Table_A\".\"bar\"," " \"Table_B\".\"that_foo\", \"Table_B\".\"that_id\"" - " FROM \"Table_A\"" - " LEFT OUTER JOIN \"Table_B\" ON \"Table_B\".\"id\" = \"Table_A\".\"that_id\"" + " FROM \"Table_A\"\n" + " LEFT OUTER JOIN \"Table_B\" ON \"Table_B\".\"id\" = \"Table_A\".\"that_id\"\n" " WHERE \"Table_A\".\"foo\" = 42")); checkSqlQueryBuilder( @@ -236,10 +291,9 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Join", "[SqlQueryBuilder]") .All(); }, QueryExpectations::All( - R"(SELECT "Table_A"."foo", "Table_A"."bar", "Table_B"."that_foo", "Table_B"."that_id")" - R"( FROM "Table_A")" - R"( INNER JOIN "Table_B" ON "Table_B"."id" = "Table_A"."that_id" AND "Table_B"."that_foo" = "Table_A"."foo")" - R"( WHERE "Table_A"."foo" = 42)")); + R"(SELECT "Table_A"."foo", "Table_A"."bar", "Table_B"."that_foo", "Table_B"."that_id" FROM "Table_A" + INNER JOIN "Table_B" ON "Table_B"."id" = "Table_A"."that_id" AND "Table_B"."that_foo" = "Table_A"."foo" + WHERE "Table_A"."foo" = 42)")); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.SelectAs", "[SqlQueryBuilder]") @@ -289,7 +343,8 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Update", "[SqlQueryBuilder]") [&](SqlQueryBuilder& q) { return q.FromTableAs("Other", "O").Update(&boundValues).Set("foo", 42).Set("bar", "baz").Where("id", 123); }, - QueryExpectations::All(R"(UPDATE "Other" AS "O" SET "foo" = ?, "bar" = ? WHERE "id" = ?)"), + QueryExpectations::All(R"(UPDATE "Other" AS "O" SET "foo" = ?, "bar" = ? + WHERE "id" = ?)"), [&]() { CHECK(boundValues.size() == 3); CHECK(std::get(boundValues[0].value) == 42); @@ -310,7 +365,8 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.Where.Lambda", "[SqlQueryBuild .OrWhere([](auto& q) { return q.Where("b", 2).Where("c", 3); }) .All(); }, - QueryExpectations::All(R"(SELECT "foo" FROM "That" WHERE "a" = 1 OR ("b" = 2 AND "c" = 3))")); + QueryExpectations::All(R"(SELECT "foo" FROM "That" + WHERE "a" = 1 OR ("b" = 2 AND "c" = 3))")); } TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.WhereColumn", "[SqlQueryBuilder]") @@ -319,7 +375,8 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder.WhereColumn", "[SqlQueryBuilde [](SqlQueryBuilder& q) { return q.FromTable("That").Select().Field("foo").WhereColumn("left", "=", "right").All(); }, - QueryExpectations::All(R"(SELECT "foo" FROM "That" WHERE "left" = "right")")); + QueryExpectations::All(R"(SELECT "foo" FROM "That" + WHERE "left" = "right")")); } TEST_CASE_METHOD(SqlTestFixture, "Varying: multiple varying final query types", "[SqlQueryBuilder]") @@ -346,9 +403,8 @@ TEST_CASE_METHOD(SqlTestFixture, "Use SqlQueryBuilder for SqlStatement.ExecuteDi { auto stmt = SqlStatement {}; - bool constexpr quoted = true; - CreateEmployeesTable(stmt, quoted); - FillEmployeesTable(stmt, quoted); + CreateEmployeesTable(stmt); + FillEmployeesTable(stmt); stmt.ExecuteDirect(stmt.Connection().Query("Employees").Select().Fields("FirstName", "LastName").All()); @@ -360,9 +416,8 @@ TEST_CASE_METHOD(SqlTestFixture, "Use SqlQueryBuilder for SqlStatement.Prepare", { auto stmt = SqlStatement {}; - bool constexpr quoted = true; - CreateEmployeesTable(stmt, quoted); - FillEmployeesTable(stmt, quoted); + CreateEmployeesTable(stmt); + FillEmployeesTable(stmt); std::vector inputBindings; @@ -387,8 +442,7 @@ TEST_CASE_METHOD(SqlTestFixture, "Use SqlQueryBuilder for SqlStatement.Prepare: { auto stmt = SqlStatement {}; - bool constexpr quoted = true; - CreateLargeTable(stmt, quoted); + CreateLargeTable(stmt); // Prepare INSERT query auto insertQuery = stmt.Connection().Query("LargeTable").Insert(nullptr /* no auto-fill */); @@ -501,3 +555,279 @@ TEST_CASE_METHOD(SqlTestFixture, "SqlQueryBuilder: sub select with WhereIn", "[S REQUIRE(!stmt.FetchRow()); } + +TEST_CASE_METHOD(SqlTestFixture, "DropTable", "[SqlQueryBuilder][Migration]") +{ + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.DropTable("Table"); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql( + DROP TABLE "Table"; + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "CreateTable with Column", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.CreateTable("Test").Column("column", Varchar { 255 }); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(CREATE TABLE "Test" ( + "column" VARCHAR(255) + ); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "CreateTable with RequiredColumn", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.CreateTable("Test").RequiredColumn("column", Varchar { 255 }); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(CREATE TABLE "Test" ( + "column" VARCHAR(255) NOT NULL + ); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "CreateTable with Column: Guid", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.CreateTable("Test").RequiredColumn("column", Guid {}); + return migration.GetPlan(); + }, + QueryExpectations { + .sqlite = R"sql(CREATE TABLE "Test" ( + "column" GUID NOT NULL + ); + )sql", + .postgres = R"sql(CREATE TABLE "Test" ( + "column" UUID NOT NULL + ); + )sql", + .sqlServer = R"sql(CREATE TABLE "Test" ( + "column" UNIQUEIDENTIFIER NOT NULL + ); + )sql", + .oracle = R"sql(CREATE TABLE "Test" ( + "column" RAW(16) NOT NULL + ); + )sql", + }); +} + +TEST_CASE_METHOD(SqlTestFixture, "CreateTable with PrimaryKey", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.CreateTable("Test").PrimaryKey("pk", Integer {}); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(CREATE TABLE "Test" ( + "pk" INTEGER NOT NULL PRIMARY KEY + ); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "CreateTable with PrimaryKeyWithAutoIncrement", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.CreateTable("Test").PrimaryKeyWithAutoIncrement("pk"); + return migration.GetPlan(); + }, + QueryExpectations { + .sqlite = R"sql(CREATE TABLE "Test" ( + "pk" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT + ); + )sql", + .postgres = R"sql(CREATE TABLE "Test" ( + "pk" SERIAL NOT NULL PRIMARY KEY + ); + )sql", + .sqlServer = R"sql(CREATE TABLE "Test" ( + "pk" BIGINT NOT NULL IDENTITY(1,1) PRIMARY KEY + ); + )sql", + .oracle = R"sql(CREATE TABLE "Test" ( + "pk" NUMBER(19,0) NOT NULL PRIMARY KEY + ); + )sql", + }); +} +TEST_CASE_METHOD(SqlTestFixture, "CreateTable with Index", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.CreateTable("Table").RequiredColumn("column", Integer {}).Index(); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(CREATE TABLE "Table" ( + "column" INTEGER NOT NULL + ); + CREATE INDEX "Table_column_index" ON "Table"("column"); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "CreateTable complex demo", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + // clang-format off + auto migration = q.Migration(); + migration.CreateTable("Test") + .PrimaryKeyWithAutoIncrement("a", Bigint {}) + .RequiredColumn("b", Varchar { 32 }).Unique() + .Column("c", DateTime {}).Index() + .Column("d", Varchar { 255 }).UniqueIndex();; + return migration.GetPlan(); + // clang-format on + }, + QueryExpectations { + .sqlite = R"sql( + CREATE TABLE "Test" ( + "a" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + "b" VARCHAR(32) NOT NULL UNIQUE, + "c" DATETIME, + "d" VARCHAR(255) + ); + CREATE INDEX "Test_c_index" ON "Test"("c"); + CREATE UNIQUE INDEX "Test_d_index" ON "Test"("d"); + )sql", + .postgres = R"sql( + CREATE TABLE "Test" ( + "a" SERIAL NOT NULL PRIMARY KEY, + "b" VARCHAR(32) NOT NULL UNIQUE, + "c" DATETIME, + "d" VARCHAR(255) + ); + CREATE INDEX "Test_c_index" ON "Test"("c"); + CREATE UNIQUE INDEX "Test_d_index" ON "Test"("d"); + )sql", + .sqlServer = R"sql( + CREATE TABLE "Test" ( + "a" BIGINT NOT NULL IDENTITY(1,1) PRIMARY KEY, + "b" VARCHAR(32) NOT NULL UNIQUE, + "c" DATETIME, + "d" VARCHAR(255) + ); + CREATE INDEX "Test_c_index" ON "Test"("c"); + CREATE UNIQUE INDEX "Test_d_index" ON "Test"("d"); + )sql", + .oracle = R"sql( + CREATE TABLE "Test" ( + "a" NUMBER GENERATED BY DEFAULT ON NULL AS IDENTITY PRIMARY KEY + "b" VARCHAR2(32 CHAR) NOT NULL UNIQUE, + "c" DATETIME, + "d" VARCHAR2(255 CHAR) + ); + CREATE INDEX "Test_c_index" ON "Test"("c"); + CREATE UNIQUE INDEX "Test_d_index" ON "Test"("d"); + )sql", + }); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable AddColumn", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").AddColumn("column", Integer {}); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(ALTER TABLE "Table" ADD COLUMN "column" INTEGER; + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable multiple AddColumn calls", "[SqlQueryBuilder][Migration]") +{ + using namespace SqlColumnTypeDefinitions; + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").AddColumn("column", Integer {}).AddColumn("column2", Varchar { 255 }); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(ALTER TABLE "Table" ADD COLUMN "column" INTEGER; + ALTER TABLE "Table" ADD COLUMN "column2" VARCHAR(255); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable RenameColumn", "[SqlQueryBuilder][Migration]") +{ + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").RenameColumn("old", "new"); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(ALTER TABLE "Table" RENAME COLUMN "old" TO "new"; + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable RenameTo", "[SqlQueryBuilder][Migration]") +{ + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").RenameTo("NewTable"); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(ALTER TABLE "Table" RENAME TO "NewTable"; + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable AddIndex", "[SqlQueryBuilder][Migration]") +{ + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").AddIndex("column"); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(CREATE INDEX "Table_column_index" ON "Table"("column"); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable AddUniqueIndex", "[SqlQueryBuilder][Migration]") +{ + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").AddUniqueIndex("column"); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(CREATE UNIQUE INDEX "Table_column_index" ON "Table"("column"); + )sql")); +} + +TEST_CASE_METHOD(SqlTestFixture, "AlterTable DropIndex", "[SqlQueryBuilder][Migration]") +{ + checkSqlQueryBuilder( + [](SqlQueryBuilder& q) { + auto migration = q.Migration(); + migration.AlterTable("Table").DropIndex("column"); + return migration.GetPlan(); + }, + QueryExpectations::All(R"sql(DROP INDEX "Table_column_index";)sql")); +} diff --git a/src/tests/Utils.hpp b/src/tests/Utils.hpp index 7e2326ee..ad6c2836 100644 --- a/src/tests/Utils.hpp +++ b/src/tests/Utils.hpp @@ -459,57 +459,37 @@ inline std::ostream& operator<<(std::ostream& os, SqlFixedString con // }}} -inline void CreateEmployeesTable(SqlStatement& stmt, - bool quoted = false, - std::source_location sourceLocation = std::source_location::current()) +inline void CreateEmployeesTable(SqlStatement& stmt, std::source_location location = std::source_location::current()) { - if (quoted) - stmt.ExecuteDirect(std::format(R"SQL(CREATE TABLE "Employees" ( - "EmployeeID" {}, - "FirstName" VARCHAR(50) NOT NULL, - "LastName" VARCHAR(50), - "Salary" INT NOT NULL - ); - )SQL", - stmt.Connection().Traits().PrimaryKeyAutoIncrement), - sourceLocation); - else - stmt.ExecuteDirect(std::format(R"SQL(CREATE TABLE Employees ( - EmployeeID {}, - FirstName VARCHAR(50) NOT NULL, - LastName VARCHAR(50), - Salary INT NOT NULL - ); - )SQL", - stmt.Connection().Traits().PrimaryKeyAutoIncrement), - sourceLocation); + stmt.MigrateDirect( + [](SqlMigrationQueryBuilder& migration) { + migration.CreateTable("Employees") + .PrimaryKeyWithAutoIncrement("EmployeeID") + .RequiredColumn("FirstName", SqlColumnTypeDefinitions::Varchar { 50 }) + .Column("LastName", SqlColumnTypeDefinitions::Varchar { 50 }) + .RequiredColumn("Salary", SqlColumnTypeDefinitions::Integer {}); + }, + location); } -inline void CreateLargeTable(SqlStatement& stmt, bool quote = false) +inline void CreateLargeTable(SqlStatement& stmt) { - std::stringstream sqlQueryStr; - auto const quoted = [quote](auto&& str) { - return quote ? std::format("\"{}\"", str) : str; - }; - sqlQueryStr << "CREATE TABLE " << quoted("LargeTable") << " (\n"; - for (char c = 'A'; c <= 'Z'; ++c) - { - sqlQueryStr << " " << quoted(std::string(1, c)) << " VARCHAR(50) NULL"; - if (c != 'Z') - sqlQueryStr << ","; - sqlQueryStr << "\n"; - } - sqlQueryStr << ")\n"; - - stmt.ExecuteDirect(sqlQueryStr.str()); + stmt.MigrateDirect([](SqlMigrationQueryBuilder& migration) { + auto table = migration.CreateTable("LargeTable"); + for (char c = 'A'; c <= 'Z'; ++c) + { + table.Column(std::string(1, c), SqlColumnTypeDefinitions::Varchar { 50 }); + } + }); } -inline void FillEmployeesTable(SqlStatement& stmt, bool quoted = false) +inline void FillEmployeesTable(SqlStatement& stmt) { - if (quoted) - stmt.Prepare(R"(INSERT INTO "Employees" ("FirstName", "LastName", "Salary") VALUES (?, ?, ?))"); - else - stmt.Prepare("INSERT INTO Employees (FirstName, LastName, Salary) VALUES (?, ?, ?)"); + stmt.Prepare(stmt.Query("Employees") + .Insert() + .Set("FirstName", SqlWildcard) + .Set("LastName", SqlWildcard) + .Set("Salary", SqlWildcard)); stmt.Execute("Alice", "Smith", 50'000); stmt.Execute("Bob", "Johnson", 60'000); stmt.Execute("Charlie", "Brown", 70'000);