diff --git a/backends/apple/coreml/runtime/kvstore/database.cpp b/backends/apple/coreml/runtime/kvstore/database.cpp index 09acad561d2..7fb9ba5e22c 100644 --- a/backends/apple/coreml/runtime/kvstore/database.cpp +++ b/backends/apple/coreml/runtime/kvstore/database.cpp @@ -8,6 +8,7 @@ #include +#include // @nocommit #include namespace { @@ -39,6 +40,7 @@ std::string toString(Database::SynchronousMode mode) { return "OFF"; } } + return "FULL"; // safe default } /// Returns the sqlite statement for a specified transaction behavior. @@ -54,6 +56,7 @@ std::string getTransactionStatement(Database::TransactionBehavior behavior) { return "BEGIN EXCLUSIVE"; } } + return "BEGIN DEFERRED"; // safe default } } // namespace @@ -90,10 +93,6 @@ int Database::OpenOptions::get_sqlite_flags() const noexcept { flags |= SQLITE_OPEN_SHAREDCACHE; } - if (is_shared_cache_option_enabled()) { - flags |= SQLITE_OPEN_SHAREDCACHE; - } - if (is_uri_option_enabled()) { flags |= SQLITE_OPEN_URI; } @@ -101,36 +100,50 @@ int Database::OpenOptions::get_sqlite_flags() const noexcept { return flags; } -bool Database::open(OpenOptions options, SynchronousMode mode, int busy_timeout_ms, std::error_code& error) noexcept { - sqlite3* handle = nullptr; - const int status = sqlite3_open_v2(file_path_.c_str(), &handle, options.get_sqlite_flags(), nullptr); - sqlite_database_.reset(handle); +bool Database::open(std::error_code& error) const { + sqlite3* tmp = nullptr; + const int status = sqlite3_open_v2(file_path_.c_str(), &tmp, open_options_.get_sqlite_flags(), nullptr); + if (!process_sqlite_status(status, error)) { + if (tmp) + sqlite3_close_v2(tmp); // ensure no leaked/half-open handle return false; } - if (!set_busy_timeout(busy_timeout_ms, error)) { - return false; - } + // Now we know it's good: install the connection + sqlite_database_.reset(tmp); - if (!execute("pragma journal_mode = WAL", error)) { + // Re-apply connection configuration + if (!set_busy_timeout(busy_timeout_ms_, error)) { + sqlite_database_.reset(nullptr); return false; } - if (!execute("pragma auto_vacuum = FULL", error)) { - return false; + const bool ro = open_options_.is_read_only_option_enabled(); + const bool in_mem = (file_path_ == ":memory:"); + if (!ro && !in_mem) { + if (!execute("pragma journal_mode = WAL", error)) { + sqlite_database_.reset(nullptr); + return false; + } + if (!execute("pragma auto_vacuum = FULL", error)) { + sqlite_database_.reset(nullptr); + return false; + } } - if (!execute("pragma synchronous = " + toString(mode), error)) { + if (!execute(std::string("pragma synchronous = ") + toString(synchronous_mode_), error)) { + sqlite_database_.reset(nullptr); return false; } + error.clear(); // clear the error return true; } bool Database::is_open() const noexcept { return sqlite_database_ != nullptr; } -bool Database::table_exists(const std::string& tableName, std::error_code& error) const noexcept { +bool Database::table_exists(const std::string& tableName, std::error_code& error) const { auto statement = prepare_statement("SELECT COUNT(*) FROM sqlite_master WHERE TYPE='table' AND NAME=?", error); if (!statement) { return false; @@ -148,16 +161,19 @@ bool Database::table_exists(const std::string& tableName, std::error_code& error if (error) { return false; } - - return (std::get(value) == 1); + if (auto p = std::get_if(&value)) { + return (*p == 1); + } + error = make_error_code(std::errc::invalid_argument); + return false; } -bool Database::drop_table(const std::string& tableName, std::error_code& error) const noexcept { +bool Database::drop_table(const std::string& tableName, std::error_code& error) const { std::string statement = "DROP TABLE IF EXISTS " + tableName; return execute(statement, error); } -int64_t Database::get_row_count(const std::string& tableName, std::error_code& error) const noexcept { +int64_t Database::get_row_count(const std::string& tableName, std::error_code& error) const { auto statement = prepare_statement("SELECT COUNT(*) FROM " + tableName, error); if (!statement) { return -1; @@ -168,76 +184,306 @@ int64_t Database::get_row_count(const std::string& tableName, std::error_code& e } auto value = statement->get_column_value(0, error); - return std::get(value); + if (auto p = std::get_if(&value)) + return *p; + error = make_error_code(std::errc::invalid_argument); + return -1; } bool Database::set_busy_timeout(int busy_timeout_ms, std::error_code& error) const noexcept { - const int status = sqlite3_busy_timeout(get_underlying_database(), busy_timeout_ms); + auto* db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return false; + } + const int status = sqlite3_busy_timeout(db, busy_timeout_ms); return process_sqlite_status(status, error); } -bool Database::execute(const std::string& statements, std::error_code& error) const noexcept { - const int status = sqlite3_exec(get_underlying_database(), statements.c_str(), nullptr, nullptr, nullptr); - return process_sqlite_status(status, error); +bool Database::execute(const std::string& statements, std::error_code& error) const { + return execute_and_maybe_retry(statements, error, true); } -int Database::get_updated_row_count() const noexcept { return sqlite3_changes(get_underlying_database()); } +bool Database::execute_and_maybe_retry(const std::string& statements, std::error_code& error, bool retry) const { -std::string Database::get_last_error_message() const noexcept { return sqlite3_errmsg(get_underlying_database()); } + std::cout << "Starting to execute statements: " << statements << std::endl; + auto* db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return false; + } + const int status = sqlite3_exec(db, statements.c_str(), nullptr, nullptr, nullptr); + if (process_sqlite_status(status, error)) { + std::cout << "Execute succeeded on first attempt" << std::endl; + std::cout << "Returning true with error code: " << error << std::endl; + error.clear(); // clear the error + return true; + } + + if (!retry) { + std::cout << "Not retrying. Returning false." << std::endl; + return false; + } + + std::cout << "Execute failed in SQL. Trying to reopen." << std::endl; + + // Only attempt a reopen if we're not inside a transaction + if (!in_transaction()) { + std::cout << "Not in transaction. Attempting to reopen." << std::endl; + db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return false; + } + const int err = sqlite3_errcode(db); + const int xerr = sqlite3_extended_errcode(db); + if (is_recoverable_connection_error(err, xerr)) { + std::cout << "Recoverable error code" << std::endl; + std::error_code reopen_ec; + if (reopen(reopen_ec)) { + std::cout << "Reopen succeeded. Trying again." << std::endl; + std::error_code retry_ec; + db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return false; + } + int retry_status = sqlite3_exec(db, statements.c_str(), nullptr, nullptr, nullptr); + if (process_sqlite_status(retry_status, retry_ec)) { + std::cout << "Retry succeeded." << std::endl; + error.clear(); // clear the error + return true; + } + std::cout << "Retry failed." << std::endl; + error = retry_ec; // surface the retry failure + } else { + std::cout << "Reopen failed." << std::endl; + error = reopen_ec; // surface the reopen failure + } + } + } + std::cout << "Reopen failed. Returning false." << std::endl; + return false; +} + +int Database::get_updated_row_count() const noexcept { + auto db = get_underlying_database(); + if (!db) { + return -1; + } + return sqlite3_changes(db); +} + +std::string Database::get_last_error_message() const noexcept { + auto db = get_underlying_database(); + if (!db) { + return ""; + } + return sqlite3_errmsg(db); +} std::unique_ptr Database::prepare_statement(const std::string& statement, std::error_code& error) const noexcept { - sqlite3_stmt* handle = getPreparedStatement(get_underlying_database(), statement, error); - return std::make_unique(std::unique_ptr(handle)); + std::cout << "Preparing statement: " << statement << std::endl; + auto db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return nullptr; + } + sqlite3_stmt* handle = getPreparedStatement(db, statement, error); + if (handle) { + std::cout << "Returning prepared statement." << std::endl; + error.clear(); + return std::make_unique(std::unique_ptr(handle)); + } + std::cout << "Failed to prepare statement. Retrying. The error code was: " << error << std::endl; + std::cout << " xerr=" << sqlite3_extended_errcode(get_underlying_database()) << std::endl; + std::cout << " msg=" << sqlite3_errmsg(get_underlying_database()) << "\n"; + if (!in_transaction()) { + std::cout << "Not in transaction. Attempting to reopen." << std::endl; + db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return nullptr; + } + const int err = sqlite3_errcode(db); + const int xerr = sqlite3_extended_errcode(db); + if (is_recoverable_connection_error(err, xerr)) { + std::error_code reopen_ec; + if (reopen(reopen_ec)) { + // try again + db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return nullptr; + } + sqlite3_stmt* h2 = getPreparedStatement(db, statement, error); + if (h2) { + return std::make_unique(std::unique_ptr(h2)); + } + // else `error` is already set by getPreparedStatement + } else { + error = reopen_ec; + } + } + } + std::cout << "Returning nullptr." << std::endl; + return nullptr; } int64_t Database::get_last_inserted_row_id() const noexcept { - return sqlite3_last_insert_rowid(get_underlying_database()); + auto db = get_underlying_database(); + if (!db) { + return -1; + } + return sqlite3_last_insert_rowid(db); } std::error_code Database::get_last_error_code() const noexcept { - int code = sqlite3_errcode(get_underlying_database()); - return static_cast(code); + auto db = get_underlying_database(); + if (!db) { + return std::make_error_code(std::errc::bad_file_descriptor); + } + int code = sqlite3_errcode(db); + return make_error_code(static_cast(code)); } std::error_code Database::get_last_extended_error_code() const noexcept { - int code = sqlite3_extended_errcode(get_underlying_database()); - return static_cast(code); + auto db = get_underlying_database(); + if (!db) { + return std::make_error_code(std::errc::bad_file_descriptor); + } + int code = sqlite3_extended_errcode(db); + return make_error_code(static_cast(code)); } -bool Database::begin_transaction(TransactionBehavior behavior, std::error_code& error) const noexcept { +bool Database::begin_transaction(TransactionBehavior behavior, std::error_code& error) const { return execute(getTransactionStatement(behavior), error); } -bool Database::commit_transaction(std::error_code& error) const noexcept { - return execute("COMMIT TRANSACTION", error); -} +bool Database::commit_transaction(std::error_code& error) const { return execute("COMMIT TRANSACTION", error); } + +bool Database::rollback_transaction(std::error_code& error) const { + auto* db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return false; + } + + // If no txn is active, treat as success. + if (sqlite3_get_autocommit(db) != 0) { + error.clear(); + return true; + } + + if (execute_and_maybe_retry("ROLLBACK TRANSACTION", error, false)) { + error.clear(); + return true; + } + + // Recovery: force-close (implicit rollback) and reopen + // Optional: aggressively finalize any live statements + db = get_underlying_database(); + if (!db) { + error = std::make_error_code(std::errc::bad_file_descriptor); + return false; + } + sqlite3_stmt* s = nullptr; + while ((s = sqlite3_next_stmt(db, s)) != nullptr) + sqlite3_finalize(s); + + sqlite_database_.reset(nullptr); // implicit rollback at close + if (open(error)) { + error.clear(); + return true; + } -bool Database::rollback_transaction(std::error_code& error) const noexcept { - return execute("ROLLBACK TRANSACTION", error); + return false; // 'error' already set } -bool Database::transaction(const std::function& fn, - TransactionBehavior behavior, - std::error_code& error) noexcept { +bool Database::transaction(const std::function& fn, TransactionBehavior behavior, std::error_code& error) { + std::cout << "Starting a new transaction." << std::endl; if (!begin_transaction(behavior, error)) { return false; } - bool status = fn(); + std::cout << "Status of transaction: " << status << std::endl; if (status) { + std::cout << "Committing transaction." << std::endl; return commit_transaction(error); } else { + std::cout << "Rolling back transaction." << std::endl; rollback_transaction(error); return false; } } +bool Database::in_transaction() const noexcept { + auto* db = get_underlying_database(); + return db && sqlite3_get_autocommit(db) == 0; +} + +bool Database::is_recoverable_connection_error(int err, int xerr) noexcept { + const int primary = (err & 0xFF); + + // Extended codes that clearly indicate a stale/broken handle or path change + switch (xerr) { + case SQLITE_READONLY_DBMOVED: + case SQLITE_IOERR_READ: + case SQLITE_IOERR_WRITE: + case SQLITE_IOERR_FSYNC: + case SQLITE_IOERR_LOCK: + case SQLITE_CANTOPEN_DIRTYWAL: + case SQLITE_CANTOPEN_NOTEMPDIR: + return true; + case SQLITE_BUSY_RECOVERY: + case SQLITE_BUSY_SNAPSHOT: + case SQLITE_LOCKED: + return false; + default: + break; + } + + // Primary classes that are generally reopenable + switch (primary) { + case SQLITE_IOERR: + case SQLITE_NOTADB: + case SQLITE_CANTOPEN: + return true; + default: + return false; + } + + return false; +} + +bool Database::reopen(std::error_code& error) const { + std::cout << "Inside reopen." << std::endl; + if (in_transaction()) { + std::cout << "In transaction. Returning false." << std::endl; + error = std::make_error_code(std::errc::operation_in_progress); + return false; + } + if (auto* db = get_underlying_database()) { + // Refuse to reopen if any live statements exist. + if (sqlite3_next_stmt(db, nullptr) != nullptr) { + error = std::make_error_code(std::errc::device_or_resource_busy); + return false; + } + } + + // Close current (best-effort) + std::cout << "Closing current." << std::endl; + sqlite_database_.reset(nullptr); + std::cout << "Calling open." << std::endl; + return open(error); +} + std::shared_ptr Database::make_inmemory(SynchronousMode mode, int busy_timeout_ms, std::error_code& error) { - auto database = std::make_shared(":memory:"); OpenOptions options; options.set_read_write_option(true); - if (database->open(options, mode, busy_timeout_ms, error)) { + auto database = std::make_shared(":memory:", options, mode, busy_timeout_ms); + if (database->open(error)) { return database; } @@ -249,8 +495,8 @@ std::shared_ptr Database::make(const std::string& filePath, SynchronousMode mode, int busy_timeout_ms, std::error_code& error) { - auto database = std::make_shared(filePath); - if (database->open(options, mode, busy_timeout_ms, error)) { + auto database = std::make_shared(filePath, options, mode, busy_timeout_ms); + if (database->open(error)) { return database; } diff --git a/backends/apple/coreml/runtime/kvstore/database.hpp b/backends/apple/coreml/runtime/kvstore/database.hpp index 61dc66a0ff9..99286b56d88 100644 --- a/backends/apple/coreml/runtime/kvstore/database.hpp +++ b/backends/apple/coreml/runtime/kvstore/database.hpp @@ -24,7 +24,7 @@ namespace sqlite { struct DatabaseDeleter { inline void operator()(sqlite3* handle) { if (handle) { - sqlite3_close(handle); + sqlite3_close_v2(handle); } } }; @@ -41,89 +41,89 @@ class Database { inline void set_read_only_option(bool enable) noexcept { flags_[0] = enable; } - + /// Returns `true` if read-only option is enabled otherwise `false`. inline bool is_read_only_option_enabled() const noexcept { return flags_[0]; } - + /// Corresponds to `SQLITE_OPEN_READWRITE` flag, when set the database will be opened in read and write mode. inline void set_read_write_option(bool enable) noexcept { flags_[1] = enable; } - + /// Returns `true` if read and write option is enabled otherwise `false`. inline bool is_read_write_option_enabled() const noexcept { return flags_[1]; } - + /// Corresponds to `SQLITE_OPEN_CREATE` flag, when set the database will be created if it does not exist. inline void set_create_option(bool enable) noexcept { flags_[2] = enable; } - + /// Returns `true` if create option is enabled otherwise `false`. inline bool is_create_option_enabled() const noexcept { return flags_[2]; } - + /// Corresponds to `SQLITE_OPEN_MEMORY` flag, when set the database will be opened as in-memory database. inline void set_memory_option(bool enable) noexcept { flags_[3] = enable; } - + /// Returns `true` if memory option is enabled otherwise `false`. inline bool is_memory_option_enabled() const noexcept { return flags_[3]; } - + /// Corresponds to `SQLITE_OPEN_NOMUTEX` flag, when set the database connection will use the "multi-thread" threading mode. inline void set_no_mutex_option(bool enable) noexcept { flags_[4] = enable; } - + /// Returns `true` if no mutex option is enabled otherwise `false`. inline bool is_no_mutex_option_enabled() const noexcept { return flags_[4]; } - + /// Corresponds to `SQLITE_OPEN_FULLMUTEX` flag, when set the database connection will use the "serialized" threading mode. inline void set_full_mutex_option(bool enable) noexcept { flags_[5] = enable; } - + /// Returns `true` if full mutex option is enabled otherwise `false`. inline bool is_full_mutex_option_enabled() const noexcept { return flags_[5]; } - + /// Corresponds to `SQLITE_OPEN_SHAREDCACHE` flag, when set the database will be opened with shared cache enabled. inline void set_shared_cache_option(bool enable) noexcept { flags_[6] = enable; } - + /// Returns `true` if shared cache option is enabled otherwise `false`. inline bool is_shared_cache_option_enabled() const noexcept { return flags_[6]; } - + /// Corresponds to `SQLITE_OPEN_URI` flag, when set the filename can be interpreted as a URI. inline void set_uri_option(bool enable) noexcept { flags_[7] = enable; } - + /// Returns `true` if URI option is enabled otherwise `false`. inline bool is_uri_option_enabled() const noexcept { return flags_[7]; } - + /// Returns the sqlite flags that can be used to open a sqlite database from the set options. int get_sqlite_flags() const noexcept; - + private: std::bitset<8> flags_; }; - + /// Represents sqlite synchronous flag. enum class SynchronousMode: uint8_t { Extra = 0, @@ -131,22 +131,25 @@ class Database { Normal, Off, }; - + /// Represents the behavior of a sqlite transaction enum class TransactionBehavior: uint8_t { Deferred = 0, Immediate, Exclusive, }; - - /// Constructs a database from a file path. - Database(const std::string& filePath) noexcept - :file_path_(filePath) + + /// Constructs a database from a file path and options + Database(const std::string& filePath, OpenOptions open_options, SynchronousMode synchronous_mode, int busy_timeout_ms) noexcept + :file_path_(filePath), + open_options_(open_options), + synchronous_mode_(synchronous_mode), + busy_timeout_ms_(busy_timeout_ms) {} - + Database(Database const&) = delete; Database& operator=(Database const&) = delete; - + /// Opens a database /// /// @param options The options for opening the database. @@ -154,62 +157,61 @@ class Database { /// @param busy_timeout_ms The busy timeout interval in milliseconds. /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the database is opened otherwise `false`. - bool open(OpenOptions options, - SynchronousMode mode, - int busy_timeout_ms, - std::error_code& error) noexcept; - + bool open(std::error_code& error) const; + /// Returns `true` is the database is opened otherwise `false`. bool is_open() const noexcept; - + /// Check if a table exists with the specified name. /// /// @param tableName The table name. /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the table exists otherwise `false`. - bool table_exists(const std::string& tableName, std::error_code& error) const noexcept; - + bool table_exists(const std::string& tableName, std::error_code& error) const; + /// Drops a table with the specified name. /// /// @param tableName The table name. /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the table is dropped otherwise `false`. - bool drop_table(const std::string& tableName, std::error_code& error) const noexcept; - + bool drop_table(const std::string& tableName, std::error_code& error) const; + /// Returns the number of rows in the table. /// /// @param tableName The table name. /// @param error On failure, error is populated with the failure reason. /// @retval The number of rows in the table. - int64_t get_row_count(const std::string& tableName, std::error_code& error) const noexcept; - + int64_t get_row_count(const std::string& tableName, std::error_code& error) const; + /// Executes the provided statements. /// /// @param statements The statements to execute. /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the execution succeeded otherwise `false`. - bool execute(const std::string& statements, std::error_code& error) const noexcept; - + bool execute(const std::string& statements, std::error_code& error) const; + + bool execute_and_maybe_retry(const std::string& statements, std::error_code& error, bool retry) const; + /// Returns the number of rows updated by the last statement. int get_updated_row_count() const noexcept; - + /// Returns the error message of the last failed sqlite call. std::string get_last_error_message() const noexcept; - + /// Returns the error code of the last failed sqlite call. std::error_code get_last_error_code() const noexcept; - + /// Returns the extended error code of the last failed sqlite call. std::error_code get_last_extended_error_code() const noexcept; - + /// Returns the value of the last inserted row id. int64_t get_last_inserted_row_id() const noexcept; - + /// Returns the file path that was used to create the database. std::string_view file_path() const noexcept { return file_path_; } - + /// Compiles the provided statement and returns it. /// /// @param statement The statement to be compiled. @@ -217,7 +219,7 @@ class Database { /// @retval The compiled statement. std::unique_ptr prepare_statement(const std::string& statement, std::error_code& error) const noexcept; - + /// Executes the provided function inside a transaction. /// /// The transaction is committed only if the provided function returns `true` otherwise the transaction is rolled-back. @@ -228,8 +230,8 @@ class Database { /// @retval `true` if the transaction is committed otherwise `false`. bool transaction(const std::function& fn, TransactionBehavior behavior, - std::error_code& error) noexcept; - + std::error_code& error); + /// Opens an in-memory database. /// /// @param mode The synchronous mode. @@ -239,7 +241,7 @@ class Database { static std::shared_ptr make_inmemory(SynchronousMode mode, int busy_timeout_ms, std::error_code& error); - + /// Creates and opens a database at the specified path. /// /// @param filePath The file path of the database. @@ -253,27 +255,34 @@ class Database { SynchronousMode mode, int busy_timeout_ms, std::error_code& error); - + private: /// Returns the internal sqlite database. inline sqlite3 *get_underlying_database() const noexcept { return sqlite_database_.get(); } - + /// Registers an internal busy handler that keeps attempting to acquire a busy lock until the total specified time has passed. bool set_busy_timeout(int busy_timeout_ms, std::error_code& error) const noexcept; - + /// Begins an explicit transaction with the specified behavior. - bool begin_transaction(TransactionBehavior behavior, std::error_code& error) const noexcept; - + bool begin_transaction(TransactionBehavior behavior, std::error_code& error) const; + /// Commits the last open transaction. - bool commit_transaction(std::error_code& error) const noexcept; - + bool commit_transaction(std::error_code& error) const; + /// Rollbacks the last open transaction. - bool rollback_transaction(std::error_code& error) const noexcept; - + bool rollback_transaction(std::error_code& error) const; + + bool reopen(std::error_code& error) const; + bool in_transaction() const noexcept; + static bool is_recoverable_connection_error(int err, int xerr) noexcept; + std::string file_path_; - std::unique_ptr sqlite_database_; + OpenOptions open_options_; + SynchronousMode synchronous_mode_; + int busy_timeout_ms_; + mutable std::unique_ptr sqlite_database_; }; } // namespace sqlite diff --git a/backends/apple/coreml/runtime/kvstore/key_value_store.cpp b/backends/apple/coreml/runtime/kvstore/key_value_store.cpp index 4a7a491236b..b0d85cd05dc 100644 --- a/backends/apple/coreml/runtime/kvstore/key_value_store.cpp +++ b/backends/apple/coreml/runtime/kvstore/key_value_store.cpp @@ -42,10 +42,10 @@ get_create_store_statement(std::string_view store_name, StorageType key_storage_ ss << "CREATE TABLE IF NOT EXISTS "; ss << store_name << " "; ss << "("; - ss << "ENTRY_KEY " << to_string(key_storage_type) << "PRIMARY KEY UNIQUE, "; + ss << "ENTRY_KEY " << to_string(key_storage_type) << " PRIMARY KEY UNIQUE, "; ss << "ENTRY_VALUE " << to_string(value_storage_type) << ", "; - ss << "ENTRY_ACCESS_COUNT " << to_string(StorageType::Integer) << ", "; - ss << "ENTRY_ACCESS_TIME " << to_string(StorageType::Integer); + ss << "ENTRY_ACCESS_COUNT INTEGER NOT NULL DEFAULT 0, "; + ss << "ENTRY_ACCESS_TIME INTEGER NOT NULL DEFAULT 0"; ss << ")"; return ss.str(); @@ -53,7 +53,8 @@ get_create_store_statement(std::string_view store_name, StorageType key_storage_ std::string get_create_index_statement(std::string_view store_name, std::string_view column_name) { std::stringstream ss; - ss << "CREATE INDEX IF NOT EXISTS " << column_name << "_INDEX" << " ON " << store_name << "(" << column_name << ")"; + ss << "CREATE INDEX IF NOT EXISTS " << store_name << "_" << column_name << "_INDEX" << " ON " << store_name << "(" + << column_name << ")"; return ss.str(); } @@ -89,8 +90,14 @@ std::string get_key_count_statement(std::string_view store_name) { std::string get_update_entry_access_statement(std::string_view store_name) { std::stringstream ss; - ss << "UPDATE " << store_name << " SET ENTRY_ACCESS_COUNT = ?, ENTRY_ACCESS_TIME = ? WHERE ENTRY_KEY = ?"; + ss << "UPDATE " << store_name + << " SET ENTRY_ACCESS_COUNT = ENTRY_ACCESS_COUNT + 1, ENTRY_ACCESS_TIME = ? WHERE ENTRY_KEY = ?"; + return ss.str(); +} +static std::string get_exists_statement(std::string_view store) { + std::stringstream ss; + ss << "SELECT 1 FROM " << store << " WHERE ENTRY_KEY = ? LIMIT 1"; return ss.str(); } @@ -159,12 +166,18 @@ bool execute(Database* database, int64_t get_last_access_time(Database* database, std::string_view storeName, std::error_code& error) { int64_t latestAccessTime = 0; auto statement = get_keys_sorted_by_column_statement(storeName, kAccessTimeColumnName, SortOrder::Descending); - std::function fn = [&latestAccessTime](const UnOwnedValue& value) { - latestAccessTime = std::get(value); - return false; - }; - return execute(database, statement, 1, fn, error); + bool ok = execute( + database, + statement, + /*columnIndex=*/2, + [&](const UnOwnedValue& v) { + latestAccessTime = std::get(v); + return false; // stop after first row + }, + error); + + return (ok && !error) ? latestAccessTime : 0; } } // namespace @@ -172,80 +185,32 @@ int64_t get_last_access_time(Database* database, std::string_view storeName, std namespace executorchcoreml { namespace sqlite { -bool KeyValueStoreImpl::init(std::error_code& error) noexcept { - if (!database_->execute(get_create_store_statement(name_, get_key_storage_type_, get_value_storage_type_), error)) { - return false; - } - - if (!database_->execute(get_create_index_statement(name_, kAccessCountColumnName), error)) { - return false; - } - - if (!database_->execute(get_create_index_statement(name_, kAccessTimeColumnName), error)) { - return false; - } - - int64_t lastAccessTime = get_last_access_time(database_.get(), name_, error); - if (error) { - return false; - } - - lastAccessTime_.store(lastAccessTime, std::memory_order_seq_cst); - return true; -} +bool KeyValueStoreImpl::init(std::error_code& error) noexcept { return ensure_schema_exists(error); } bool KeyValueStoreImpl::exists(const Value& key, std::error_code& error) noexcept { - if (error) { + error = {}; // ensure "miss" can be distinguished from "error" + if (!ensure_schema_exists(error)) return false; - } - - auto query = database_->prepare_statement(get_key_count_statement(name_), error); - if (!query) { - return false; - } - - if (!bind_value(query.get(), get_key_storage_type(), key, 1, error)) { - return false; - } - - if (!query->step(error)) { - return false; - } - - return std::get(query->get_column_value(0, error)) > 0; -} - -bool KeyValueStoreImpl::updateValueAccessCountAndTime(const Value& key, - int64_t accessCount, - std::error_code& error) noexcept { - auto update = database_->prepare_statement(get_update_entry_access_statement(name_), error); - if (!update) { - return false; - } - - if (!bind_value(update.get(), StorageType::Integer, accessCount + 1, 1, error)) { + auto q = database_->prepare_statement(get_exists_statement(name_), error); + if (!q) return false; - } - - if (!bind_value(update.get(), StorageType::Integer, lastAccessTime_, 2, error)) { + if (!bind_value(q.get(), get_key_storage_type(), key, 1, error)) return false; - } - - if (!bind_value(update.get(), get_key_storage_type(), key, 3, error)) { + bool has_row = q->step(error); + if (error) return false; - } - - bool result = update->execute(error); - if (result) { - lastAccessTime_ += 1; - } - return result; + return has_row; } bool KeyValueStoreImpl::get(const Value& key, const std::function& fn, std::error_code& error, bool updateAccessStatistics) noexcept { + error = {}; // ensure "miss" can be distinguished from "error" + + if (!ensure_schema_exists(error)) + return false; + auto query = database_->prepare_statement(getQueryStatement(name_), error); if (!query) { return false; @@ -255,22 +220,47 @@ bool KeyValueStoreImpl::get(const Value& key, return false; } - if (!query->step(error)) { + bool has_row = query->step(error); + if (error) + return false; + if (!has_row) { + error = {}; return false; } auto value = query->get_column_value_no_copy(0, error); + + if (error) + return false; + fn(value); if (updateAccessStatistics) { - int64_t accessCount = std::get(query->get_column_value(1, error)); - return updateValueAccessCountAndTime(key, accessCount, error); + auto update = database_->prepare_statement(get_update_entry_access_statement(name_), error); + if (!update) + return false; + + auto next = lastAccessTime_.load(std::memory_order_acquire) + 1; + if (!bind_value(update.get(), StorageType::Integer, next, 1, error)) + return false; + if (!bind_value(update.get(), get_key_storage_type(), key, 2, error)) + return false; + bool ok = update->execute(error); + if (ok && !error) { + lastAccessTime_.store(next, std::memory_order_release); + } + return ok && !error; } return true; } bool KeyValueStoreImpl::put(const Value& key, const Value& value, std::error_code& error) noexcept { + error = {}; // clear error + + if (!ensure_schema_exists(error)) + return false; + auto statement = database_->prepare_statement(get_insert_or_replace_statement(name_), error); if (!statement) { return false; @@ -288,16 +278,41 @@ bool KeyValueStoreImpl::put(const Value& key, const Value& value, std::error_cod return false; } - if (!bind_value(statement.get(), StorageType::Integer, lastAccessTime_.load(std::memory_order_acquire), 4, error)) { + auto next = lastAccessTime_.load(std::memory_order_acquire) + 1; + if (!bind_value(statement.get(), StorageType::Integer, next, 4, error)) { return false; } + bool ok = statement->execute(error); + if (ok && !error) { + lastAccessTime_.store(next, std::memory_order_release); + } + return ok && !error; +} - lastAccessTime_ += 1; - return statement->execute(error); +bool KeyValueStoreImpl::ensure_schema_exists(std::error_code& error) noexcept { + error = {}; + if (!database_->execute(get_create_store_statement(name_, get_key_storage_type(), get_value_storage_type()), error)) + return false; + if (!database_->execute(get_create_index_statement(name_, kAccessCountColumnName), error)) + return false; + if (!database_->execute(get_create_index_statement(name_, kAccessTimeColumnName), error)) + return false; + + // Always recompute (cheap with the index). + auto t = get_last_access_time(database_.get(), name_, error); + if (error) + return false; + lastAccessTime_.store(t, std::memory_order_seq_cst); + return true; } bool KeyValueStoreImpl::remove(const Value& key, std::error_code& error) noexcept { + error = {}; // clear error + if (!ensure_schema_exists(error)) + return false; auto statement = database_->prepare_statement(get_remove_statement(name_), error); + if (!statement) + return false; if (!bind_value(statement.get(), get_key_storage_type(), key, 1, error)) { return false; } @@ -308,6 +323,9 @@ bool KeyValueStoreImpl::remove(const Value& key, std::error_code& error) noexcep bool KeyValueStoreImpl::get_keys_sorted_by_access_count(const std::function& fn, SortOrder order, std::error_code& error) noexcept { + error = {}; // clear error + if (!ensure_schema_exists(error)) + return false; auto statement = get_keys_sorted_by_column_statement(name(), kAccessCountColumnName, order); return execute(database_.get(), statement, 0, fn, error); } @@ -315,21 +333,27 @@ bool KeyValueStoreImpl::get_keys_sorted_by_access_count(const std::function& fn, SortOrder order, std::error_code& error) noexcept { + error = {}; // clear error + if (!ensure_schema_exists(error)) + return false; auto statement = get_keys_sorted_by_column_statement(name(), kAccessTimeColumnName, order); return execute(database_.get(), statement, 0, fn, error); } std::optional KeyValueStoreImpl::size(std::error_code& error) noexcept { + error = {}; // clear error + if (!ensure_schema_exists(error)) + return std::nullopt; int64_t count = database_->get_row_count(name_, error); return count < 0 ? std::nullopt : std::optional(count); } bool KeyValueStoreImpl::purge(std::error_code& error) noexcept { + error = {}; // clear error if (!database_->drop_table(name_, error)) { return false; } - - return init(error); + return ensure_schema_exists(error); } } // namespace sqlite diff --git a/backends/apple/coreml/runtime/kvstore/key_value_store.hpp b/backends/apple/coreml/runtime/kvstore/key_value_store.hpp index c02ad38059b..8ad0e74e1c0 100644 --- a/backends/apple/coreml/runtime/kvstore/key_value_store.hpp +++ b/backends/apple/coreml/runtime/kvstore/key_value_store.hpp @@ -7,12 +7,13 @@ #pragma once -#import +#include #include #include #include #include #include +#include #include #include @@ -27,10 +28,10 @@ namespace sqlite { template struct Converter { static constexpr StorageType storage_type = StorageType::Null; - + template static sqlite::Value to_sqlite_value(FROM&& value); - + static T from_sqlite_value(const sqlite::UnOwnedValue& value); }; @@ -38,11 +39,11 @@ struct Converter { template<> struct Converter { static constexpr StorageType storage_type = StorageType::Integer; - - static inline Value to_sqlite_value(int value) { + + static inline Value to_sqlite_value(int64_t value) { return value; } - + static inline int64_t from_sqlite_value(const sqlite::UnOwnedValue& value) { return std::get(value); } @@ -52,12 +53,12 @@ struct Converter { template<> struct Converter { static constexpr StorageType storage_type = StorageType::Integer; - + static inline Value to_sqlite_value(int value) { - return static_cast(value); + return static_cast(value); } - - static inline int from_sqlite_value(const sqlite::UnOwnedValue& value) { + + static inline int from_sqlite_value(const sqlite::UnOwnedValue& value) { return static_cast(std::get(value)); } }; @@ -66,11 +67,11 @@ struct Converter { template<> struct Converter { static constexpr StorageType storage_type = StorageType::Integer; - + static inline Value to_sqlite_value(size_t value) { - return static_cast(value); + return static_cast(value); } - + static inline size_t from_sqlite_value(const sqlite::UnOwnedValue& value) { return static_cast(std::get(value)); } @@ -80,12 +81,12 @@ struct Converter { template<> struct Converter { static constexpr StorageType storage_type = StorageType::Double; - + static inline Value to_sqlite_value(double value) { return value; } - - static inline int from_sqlite_value(const UnOwnedValue& value) { + + static inline double from_sqlite_value(const UnOwnedValue& value) { return std::get(value); } }; @@ -94,13 +95,14 @@ struct Converter { template<> struct Converter { static constexpr sqlite::StorageType storage_type = StorageType::Text; - + static inline sqlite::Value to_sqlite_value(const std::string& value) { return value; } - + static inline std::string from_sqlite_value(const UnOwnedValue& value) { - return std::string(std::get(value).data); + const auto s = std::get(value); + return std::string(s.data, s.size); } }; @@ -131,30 +133,30 @@ class KeyValueStoreImpl { get_value_storage_type_(get_value_storage_type), database_(std::move(database)) {} - + KeyValueStoreImpl(KeyValueStoreImpl const&) noexcept = delete; KeyValueStoreImpl& operator=(KeyValueStoreImpl const&) noexcept = delete; - + /// Returns the name of the store. inline std::string_view name() const noexcept { return name_; } - + /// Returns the key storage type. inline StorageType get_key_storage_type() const noexcept { return get_key_storage_type_; } - + /// Returns the value storage type. inline StorageType get_value_storage_type() const noexcept { return get_value_storage_type_; } - + /// Returns the sqlite database. inline Database *database() const noexcept { return database_.get(); } - + /// Returns the value for the specified key. If the key doesn't exists in the store or for some reason the operation failed /// then `nullopt` is returned. /// @@ -167,7 +169,7 @@ class KeyValueStoreImpl { const std::function& fn, std::error_code& error, bool update_access_statistics) noexcept; - + /// Returns `true` if the key exists in the store otherwise `false`. /// /// @param key The key. @@ -175,7 +177,11 @@ class KeyValueStoreImpl { /// @retval `true` if the key exists in the store otherwise `false`. bool exists(const Value& key, std::error_code& error) noexcept; - + + // Returns true if successful and false on error + // Ensures that the backing table exists. If the table doesn't exist then it is created. + bool ensure_schema_exists(std::error_code& error) noexcept; + /// Sorts the keys by the access count and calls the `std::function` on each key value. The sort order /// is specified by the `order` parameter. The caller can stop the iteration by returning `false` /// from the lambda, to continue the iteration the caller must return `true`. @@ -187,7 +193,7 @@ class KeyValueStoreImpl { bool get_keys_sorted_by_access_count(const std::function& fn, SortOrder order, std::error_code& error) noexcept; - + /// Sorts the keys by the access time and calls the `std::function` on each key value. The sort order /// is specified by the `order` parameter. The caller can stop the iteration by returning `false` /// from the lambda, to continue the iteration the caller must return `true`. @@ -199,7 +205,7 @@ class KeyValueStoreImpl { bool get_keys_sorted_by_access_time(const std::function& fn, SortOrder order, std::error_code& error) noexcept; - + /// Stores a key and a value in the store, the old value is overwritten. /// /// @param key The key. @@ -207,31 +213,31 @@ class KeyValueStoreImpl { /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the operation succeeded otherwise `false`. bool put(const Value& key, const Value& value, std::error_code& error) noexcept; - + /// Removes the specified key and the associated value from the store. /// /// @param key The key. /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the operation succeeded otherwise `false`. bool remove(const Value& key, std::error_code& error) noexcept; - + /// Purges the store. The backing table is dropped and re-created. bool purge(std::error_code& error) noexcept; - + /// Returns the size of the store. std::optional size(std::error_code& error) noexcept; - + /// Initializes the store. /// /// @param error On failure, error is populated with the failure reason. /// @retval `true` if the operation succeeded otherwise `false`. bool init(std::error_code& error) noexcept; - + private: bool updateValueAccessCountAndTime(const Value& key, int64_t accessCount, std::error_code& error) noexcept; - + std::string name_; StorageType get_key_storage_type_; StorageType get_value_storage_type_; @@ -245,16 +251,16 @@ class KeyValueStore final { public: template using same_key = std::is_same, Key>; template using same_value = std::is_same, Value>; - + virtual ~KeyValueStore() = default; - + KeyValueStore(KeyValueStore const&) noexcept = delete; KeyValueStore& operator=(KeyValueStore const&) noexcept = delete; - + inline KeyValueStore(std::unique_ptr impl) noexcept :impl_(std::move(impl)) {} - + /// Executes the provided lambda inside a transaction. The lambda must return `true` if the transaction is to /// be committed otherwise `false`. /// @@ -270,8 +276,8 @@ class KeyValueStore final { return fn(); }, behavior, error); } - - + + /// Returns the value for the specified key. If the key doesn't exists in the store or the operation failed /// then `nullopt` is returned. /// @@ -285,14 +291,14 @@ class KeyValueStore final { std::function fn = [&result](const UnOwnedValue& value) { result = ValueConverter::from_sqlite_value(value); }; - + if (!impl_->get(KeyConverter::to_sqlite_value(std::forward(key)), fn, error, update_access_statistics)) { return std::nullopt; } - + return result; } - + /// Returns `true` if the key exists in the store otherwise `false`. /// /// @param key The key. @@ -302,7 +308,7 @@ class KeyValueStore final { inline bool exists(T&& key, std::error_code& error) noexcept { return impl_->exists(KeyConverter::to_sqlite_value(std::forward(key)), error); } - + /// Stores a key and its associated value in the store, the old value is overwritten. /// /// @param key The key. @@ -315,7 +321,7 @@ class KeyValueStore final { ValueConverter::to_sqlite_value(std::forward(value)), error); } - + /// Sorts the keys by the access count and calls the lambda on each key value. The sort order /// is specified by the `order` parameter. The caller can stop the iteration by returning `false` /// from the lambda, to continue the iteration the caller must return `true`. @@ -331,10 +337,10 @@ class KeyValueStore final { std::function wrappedFn = [&fn](const UnOwnedValue& value) { return fn(KeyConverter::from_sqlite_value(value)); }; - + return impl_->get_keys_sorted_by_access_count(wrappedFn, order, error); } - + /// Sorts the keys by the access time and calls the lambda on each key value. The sort order /// is specified by the `order` parameter. The caller can stop the iteration by returning `false` /// from the lambda, to continue the iteration the caller must return `true`. @@ -350,10 +356,10 @@ class KeyValueStore final { std::function wrappedFn = [&fn](const UnOwnedValue& value) { return fn(KeyConverter::from_sqlite_value(value)); }; - + return impl_->get_keys_sorted_by_access_time(wrappedFn, order, error); } - + /// Removes the specified key and its associated value from the store. /// /// @param key The key. @@ -361,24 +367,24 @@ class KeyValueStore final { /// @retval `true` if the operation succeeded otherwise `false`. template inline bool remove(T&& key, std::error_code& error) noexcept { - return impl_->remove(Converter::to_sqlite_value(std::forward(key)), error); + return impl_->remove(KeyConverter::to_sqlite_value(std::forward(key)), error); } - + /// Returns the name of the store. inline std::string_view name() const noexcept { return impl_->name(); } - + /// Returns the size of the store. inline std::optional size(std::error_code& error) const noexcept { return impl_->size(error); } - + /// Purges the store. The backing table is dropped and re-created. inline bool purge(std::error_code& error) noexcept { return impl_->purge(error); } - + /// Creates a typed KeyValue store. /// /// The returned store's key type is `KeyType` and the value type is `ValueType`. The store @@ -398,10 +404,10 @@ class KeyValueStore final { if (!impl->init(error)) { return nullptr; } - + return std::make_unique>(std::move(impl)); } - + private: std::unique_ptr impl_; };