diff --git a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt index 79f351ae9cb..ac18a9bc7cd 100644 --- a/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/odbc/CMakeLists.txt @@ -34,6 +34,7 @@ else() endif() add_subdirectory(odbc_impl) +add_subdirectory(tests) arrow_install_all_headers("arrow/flight/sql/odbc") diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h index b3c1030f40e..a5cc3a6f4c8 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/encoding_utils.h @@ -36,6 +36,7 @@ namespace ODBC { using arrow::flight::sql::odbc::DriverException; using arrow::flight::sql::odbc::GetSqlWCharSize; using arrow::flight::sql::odbc::Utf8ToWcs; +using arrow::flight::sql::odbc::WcsToUtf8; // Return the number of bytes required for the conversion. template @@ -80,4 +81,24 @@ inline size_t ConvertToSqlWChar(const std::string& str, SQLWCHAR* buffer, } } +/// \brief Convert buffer of SqlWchar to standard string +/// \param[in] wchar_msg SqlWchar to convert +/// \param[in] msg_len Number of characters in wchar_msg +/// \return wchar_msg in std::string format +inline std::string SqlWcharToString(SQLWCHAR* wchar_msg, SQLINTEGER msg_len = SQL_NTS) { + if (msg_len == 0 || !wchar_msg || wchar_msg[0] == 0) { + return std::string(); + } + + thread_local std::vector utf8_str; + + if (msg_len == SQL_NTS) { + WcsToUtf8((void*)wchar_msg, &utf8_str); + } else { + WcsToUtf8((void*)wchar_msg, msg_len, &utf8_str); + } + + return std::string(utf8_str.begin(), utf8_str.end()); +} + } // namespace ODBC diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc index d10eff2580f..75501ac8dd4 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.cc @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h" + // platform.h includes windows.h, so it needs to be included // before winuser.h #include "arrow/flight/sql/odbc/odbc_impl/platform.h" @@ -33,13 +35,13 @@ #include #include -using arrow::flight::sql::odbc::DriverException; -using arrow::flight::sql::odbc::FlightSqlConnection; -using arrow::flight::sql::odbc::config::Configuration; -using arrow::flight::sql::odbc::config::ConnectionStringParser; -using arrow::flight::sql::odbc::config::DsnConfigurationWindow; -using arrow::flight::sql::odbc::config::Result; -using arrow::flight::sql::odbc::config::Window; +namespace arrow::flight::sql::odbc { + +using config::Configuration; +using config::ConnectionStringParser; +using config::DsnConfigurationWindow; +using config::Result; +using config::Window; bool DisplayConnectionWindow(void* window_parent, Configuration& config) { HWND hwnd_parent = (HWND)window_parent; @@ -237,3 +239,5 @@ BOOL INSTAPI ConfigDSNW(HWND hwnd_parent, WORD req, LPCWSTR wdriver, return TRUE; } + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h new file mode 100644 index 00000000000..32d17af6753 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/system_dsn.h @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// platform.h includes windows.h, so it needs to be included first +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" + +namespace arrow::flight::sql::odbc { + +using config::Configuration; + +#if defined _WIN32 +/** + * Display connection window for user to configure connection parameters. + * + * @param window_parent Parent window handle. + * @param config Output configuration. + * @return True on success and false on fail. + */ +bool DisplayConnectionWindow(void* window_parent, Configuration& config); + +/** + * For SQLDriverConnect. + * Display connection window for user to configure connection parameters. + * + * @param window_parent Parent window handle. + * @param config Output configuration, presumed to be empty, it will be using values from + * properties. + * @param properties Output properties. + * @return True on success and false on fail. + */ +bool DisplayConnectionWindow(void* window_parent, Configuration& config, + Connection::ConnPropertyMap& properties); +#endif + +/** + * Register DSN with specified configuration. + * + * @param config Configuration. + * @param driver Driver. + * @return True on success and false on fail. + */ +bool RegisterDsn(const Configuration& config, LPCWSTR driver); + +/** + * Unregister specified DSN. + * + * @param dsn DSN name. + * @return True on success and false on fail. + */ +bool UnregisterDsn(const std::wstring& dsn); + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt new file mode 100644 index 00000000000..4bc240637e7 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/CMakeLists.txt @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +add_custom_target(tests) + +find_package(ODBC REQUIRED) +include_directories(${ODBC_INCLUDE_DIRS}) + +find_package(SQLite3Alt REQUIRED) + +set(ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS + ../../example/sqlite_sql_info.cc + ../../example/sqlite_type_info.cc + ../../example/sqlite_statement.cc + ../../example/sqlite_statement_batch_reader.cc + ../../example/sqlite_server.cc + ../../example/sqlite_tables_schema_batch_reader.cc) + +add_arrow_test(flight_sql_odbc_test + SOURCES + odbc_test_suite.cc + odbc_test_suite.h + connection_test.cc + # Enable Protobuf cleanup after test execution + # GH-46889: move protobuf_test_util to a more common location + ../../../../engine/substrait/protobuf_test_util.cc + ${ARROW_FLIGHT_SQL_MOCK_SERVER_SRCS} + EXTRA_LINK_LIBS + ${ODBC_LIBRARIES} + ${ODBCINST} + ${SQLite3_LIBRARIES} + arrow_odbc_spi_impl) diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc new file mode 100644 index 00000000000..fa1ccf2854f --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifdef _WIN32 +# include +#endif + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace arrow::flight::sql::odbc { + +TEST(SQLAllocHandle, SQLAllocHandleEnv) { + // Allocate an environment handle + SQLHENV env = nullptr; + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env)); + + // Check for valid handle + ASSERT_NE(nullptr, env); + + // Free an environment handle + ASSERT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env)); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc new file mode 100644 index 00000000000..fccb5525759 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.cc @@ -0,0 +1,504 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// For DSN registration. flight_sql_connection.h needs to included first due to conflicts +// with windows.h +#include "arrow/flight/sql/odbc/odbc_impl/flight_sql_connection.h" + +#include "arrow/flight/sql/odbc/tests/odbc_test_suite.h" + +// For DSN registration +#include "arrow/flight/sql/odbc/odbc_impl/config/configuration.h" +#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h" +#include "arrow/flight/sql/odbc/odbc_impl/odbc_connection.h" + +namespace arrow::flight::sql::odbc { + +void FlightSQLODBCRemoteTestBase::AllocEnvConnHandles(SQLINTEGER odbc_ver) { + // Allocate an environment handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocEnv(&env)); + + ASSERT_EQ( + SQL_SUCCESS, + SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, + reinterpret_cast(static_cast(odbc_ver)), 0)); + + // Allocate a connection using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_DBC, env, &conn)); +} + +void FlightSQLODBCRemoteTestBase::Connect(SQLINTEGER odbc_ver) { + ASSERT_NO_FATAL_FAILURE(AllocEnvConnHandles(odbc_ver)); + std::string connect_str = GetConnectionString(); + ASSERT_NO_FATAL_FAILURE(ConnectWithString(connect_str)); +} + +void FlightSQLODBCRemoteTestBase::ConnectWithString(std::string connect_str) { + // Connect string + std::vector connect_str0(connect_str.begin(), connect_str.end()); + + SQLWCHAR out_str[kOdbcBufferSize]; + SQLSMALLINT out_str_len; + + // Connecting to ODBC server. + ASSERT_EQ(SQL_SUCCESS, + SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), out_str, + kOdbcBufferSize, &out_str_len, SQL_DRIVER_NOPROMPT)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + + // Allocate a statement using alloc handle + ASSERT_EQ(SQL_SUCCESS, SQLAllocHandle(SQL_HANDLE_STMT, conn, &stmt)); +} + +void FlightSQLODBCRemoteTestBase::Disconnect() { + // Close statement + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_STMT, stmt)); + + // Disconnect from ODBC + EXPECT_EQ(SQL_SUCCESS, SQLDisconnect(conn)) + << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn); + + // Free connection handle + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_DBC, conn)); + + // Free environment handle + EXPECT_EQ(SQL_SUCCESS, SQLFreeHandle(SQL_HANDLE_ENV, env)); +} + +std::string FlightSQLODBCRemoteTestBase::GetConnectionString() { + std::string connect_str = + arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOrDie(); + return connect_str; +} + +std::string FlightSQLODBCRemoteTestBase::GetInvalidConnectionString() { + std::string connect_str = GetConnectionString(); + // Append invalid uid to connection string + connect_str += std::string("uid=non_existent_id;"); + return connect_str; +} + +std::wstring FlightSQLODBCRemoteTestBase::GetQueryAllDataTypes() { + std::wstring wsql = + LR"( SELECT + -- Numeric types + -128 as stiny_int_min, 127 as stiny_int_max, + 0 as utiny_int_min, 255 as utiny_int_max, + + -32768 as ssmall_int_min, 32767 as ssmall_int_max, + 0 as usmall_int_min, 65535 as usmall_int_max, + + CAST(-2147483648 AS INTEGER) AS sinteger_min, + CAST(2147483647 AS INTEGER) AS sinteger_max, + CAST(0 AS BIGINT) AS uinteger_min, + CAST(4294967295 AS BIGINT) AS uinteger_max, + + CAST(-9223372036854775808 AS BIGINT) AS sbigint_min, + CAST(9223372036854775807 AS BIGINT) AS sbigint_max, + CAST(0 AS BIGINT) AS ubigint_min, + --Use string to represent unsigned big int due to lack of support from + --remote test server + '18446744073709551615' AS ubigint_max, + + CAST(-999999999 AS DECIMAL(38, 0)) AS decimal_negative, + CAST(999999999 AS DECIMAL(38, 0)) AS decimal_positive, + + CAST(-3.40282347E38 AS FLOAT) AS float_min, CAST(3.40282347E38 AS FLOAT) AS float_max, + + CAST(-1.7976931348623157E308 AS DOUBLE) AS double_min, + CAST(1.7976931348623157E308 AS DOUBLE) AS double_max, + + --Boolean + CAST(false AS BOOLEAN) AS bit_false, + CAST(true AS BOOLEAN) AS bit_true, + + --Character types + 'Z' AS c_char, '你' AS c_wchar, + + '你好' AS c_wvarchar, + + 'XYZ' AS c_varchar, + + --Date / timestamp + CAST(DATE '1400-01-01' AS DATE) AS date_min, + CAST(DATE '9999-12-31' AS DATE) AS date_max, + + CAST(TIMESTAMP '1400-01-01 00:00:00' AS TIMESTAMP) AS timestamp_min, + CAST(TIMESTAMP '9999-12-31 23:59:59' AS TIMESTAMP) AS timestamp_max; + )"; + return wsql; +} + +void FlightSQLODBCRemoteTestBase::SetUp() { + if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) { + GTEST_SKIP() << "Skipping test: kTestConnectStr not set"; + } + + this->Connect(); + connected_ = true; +} + +void FlightSQLODBCRemoteTestBase::TearDown() { + if (connected_) { + this->Disconnect(); + connected_ = false; + } +} + +void FlightSQLOdbcV2RemoteTestBase::SetUp() { + if (arrow::internal::GetEnvVar(kTestConnectStr.data()).ValueOr("").empty()) { + GTEST_SKIP() << "Skipping test: kTestConnectStr not set"; + } + + this->Connect(SQL_OV_ODBC2); + connected_ = true; +} + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + std::string bearer_token(""); + auto auth_header = incoming_headers.find(kAuthorizationHeader); + if (auth_header != incoming_headers.end()) { + const std::string auth_val(auth_header->second); + if (auth_val.size() > kBearerPrefix.length()) { + if (std::equal(auth_val.begin(), auth_val.begin() + kBearerPrefix.length(), + kBearerPrefix.begin(), char_compare)) { + bearer_token = auth_val.substr(kBearerPrefix.length()); + } + } + } + return bearer_token; +} + +void MockServerMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { + std::string bearer_token = FindTokenInCallHeaders(incoming_headers_); + *is_valid_ = (bearer_token == std::string(kTestToken)); +} + +Status MockServerMiddlewareFactory::StartCall( + const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware) { + std::string bearer_token = FindTokenInCallHeaders(context.incoming_headers()); + if (bearer_token == std::string(kTestToken)) { + *middleware = + std::make_shared(context.incoming_headers(), &is_valid_); + } else { + return MakeFlightError(FlightStatusCode::Unauthenticated, + "Invalid token for mock server"); + } + + return Status::OK(); +} + +std::string FlightSQLODBCMockTestBase::GetConnectionString() { + std::string connect_str( + "driver={Apache Arrow Flight SQL ODBC Driver};HOST=localhost;port=" + + std::to_string(port) + ";token=" + std::string(kTestToken) + + ";useEncryption=false;"); + return connect_str; +} + +std::string FlightSQLODBCMockTestBase::GetInvalidConnectionString() { + std::string connect_str = GetConnectionString(); + // Append invalid token to connection string + connect_str += std::string("token=invalid_token;"); + return connect_str; +} + +std::wstring FlightSQLODBCMockTestBase::GetQueryAllDataTypes() { + std::wstring wsql = + LR"( SELECT + -- Numeric types + -128 AS stiny_int_min, 127 AS stiny_int_max, + 0 AS utiny_int_min, 255 AS utiny_int_max, + + -32768 AS ssmall_int_min, 32767 AS ssmall_int_max, + 0 AS usmall_int_min, 65535 AS usmall_int_max, + + CAST(-2147483648 AS INTEGER) AS sinteger_min, + CAST(2147483647 AS INTEGER) AS sinteger_max, + CAST(0 AS INTEGER) AS uinteger_min, + CAST(4294967295 AS INTEGER) AS uinteger_max, + + CAST(-9223372036854775808 AS INTEGER) AS sbigint_min, + CAST(9223372036854775807 AS INTEGER) AS sbigint_max, + CAST(0 AS INTEGER) AS ubigint_min, + -- stored as TEXT as SQLite doesn't support unsigned big int + '18446744073709551615' AS ubigint_max, + + CAST('-999999999' AS NUMERIC) AS decimal_negative, + CAST('999999999' AS NUMERIC) AS decimal_positive, + + CAST(-3.40282347E38 AS REAL) AS float_min, + CAST(3.40282347E38 AS REAL) AS float_max, + + CAST(-1.7976931348623157E308 AS REAL) AS double_min, + CAST(1.7976931348623157E308 AS REAL) AS double_max, + + -- Boolean + 0 AS bit_false, + 1 AS bit_true, + + -- Character types + 'Z' AS c_char, + '你' AS c_wchar, + '你好' AS c_wvarchar, + 'XYZ' AS c_varchar, + + DATE('1400-01-01') AS date_min, + DATE('9999-12-31') AS date_max, + + DATETIME('1400-01-01 00:00:00') AS timestamp_min, + DATETIME('9999-12-31 23:59:59') AS timestamp_max; + )"; + return wsql; +} + +void FlightSQLODBCMockTestBase::CreateTestTables() { + ASSERT_OK(server_->ExecuteSql(R"( + CREATE TABLE TestTable ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + keyName varchar(100), + value int); + + INSERT INTO TestTable (keyName, value) VALUES ('One', 1); + INSERT INTO TestTable (keyName, value) VALUES ('Two', 0); + INSERT INTO TestTable (keyName, value) VALUES ('Three', -1); + )")); +} + +void FlightSQLODBCMockTestBase::CreateTableAllDataType() { + // Limitation on mock SQLite server: + // Only int64, float64, binary, and utf8 Arrow Types are supported by + // SQLiteFlightSqlServer::Impl::DoGetTables + ASSERT_OK(server_->ExecuteSql(R"( + CREATE TABLE AllTypesTable( + bigint_col INTEGER PRIMARY KEY AUTOINCREMENT, + char_col varchar(100), + varbinary_col BLOB, + double_col REAL); + + INSERT INTO AllTypesTable ( + char_col, + varbinary_col, + double_col) VALUES ( + '1st Row', + X'31737420726F77', + 3.14159 + ); + )")); +} + +void FlightSQLODBCMockTestBase::CreateUnicodeTable() { + std::string unicode_sql = arrow::util::WideStringToUTF8( + LR"( + CREATE TABLE 数据( + 资料 varchar(100)); + + INSERT INTO 数据 (资料) VALUES ('第一行'); + INSERT INTO 数据 (资料) VALUES ('二行'); + INSERT INTO 数据 (资料) VALUES ('3rd Row'); + )") + .ValueOr(""); + ASSERT_OK(server_->ExecuteSql(unicode_sql)); +} + +void FlightSQLODBCMockTestBase::Initialize() { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options(location); + options.auth_handler = std::make_unique(); + options.middleware.push_back( + {"bearer-auth-server", std::make_shared()}); + ASSERT_OK_AND_ASSIGN(server_, + arrow::flight::sql::example::SQLiteFlightSqlServer::Create()); + ASSERT_OK(server_->Init(options)); + + port = server_->port(); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", port)); + ASSERT_OK_AND_ASSIGN(auto client, arrow::flight::FlightClient::Connect(location)); +} + +void FlightSQLODBCMockTestBase::SetUp() { + this->Initialize(); + this->Connect(); + connected_ = true; +} + +void FlightSQLODBCMockTestBase::TearDown() { + if (connected_) { + this->Disconnect(); + connected_ = false; + } + ASSERT_OK(server_->Shutdown()); +} + +void FlightSQLOdbcV2MockTestBase::SetUp() { + this->Initialize(); + this->Connect(SQL_OV_ODBC2); + connected_ = true; +} + +bool CompareConnPropertyMap(Connection::ConnPropertyMap map1, + Connection::ConnPropertyMap map2) { + if (map1.size() != map2.size()) return false; + + for (const auto& [key, value] : map1) { + if (value != map2[key]) return false; + } + + return true; +} + +void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle, + std::string_view expected_state) { + using ODBC::SqlWcharToString; + + SQLWCHAR sql_state[7] = {}; + SQLINTEGER native_code; + + SQLWCHAR message[kOdbcBufferSize] = {}; + SQLSMALLINT real_len = 0; + + // On Windows, real_len is in bytes. On Linux, real_len is in chars. + // So, not using real_len + SQLGetDiagRec(handle_type, handle, 1, sql_state, &native_code, message, kOdbcBufferSize, + &real_len); + + EXPECT_EQ(expected_state, SqlWcharToString(sql_state)); +} + +std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle) { + using ODBC::SqlWcharToString; + + SQLWCHAR sql_state[7] = {}; + SQLINTEGER native_code; + + SQLWCHAR message[kOdbcBufferSize] = {}; + SQLSMALLINT real_len = 0; + + // On Windows, real_len is in bytes. On Linux, real_len is in chars. + // So, not using real_len + SQLGetDiagRec(handle_type, handle, 1, sql_state, &native_code, message, kOdbcBufferSize, + &real_len); + + std::string res = SqlWcharToString(sql_state); + + if (res.empty() || !message[0]) { + res = "Cannot find ODBC error message"; + } else { + res.append(": ").append(SqlWcharToString(message)); + } + + return res; +} + +// TODO: once RegisterDsn is implemented in Mac and Linux, the following can be +// re-enabled. +#if defined _WIN32 +bool WriteDSN(std::string connection_str) { + Connection::ConnPropertyMap properties; + + ODBC::ODBCConnection::GetPropertiesFromConnString(connection_str, properties); + return WriteDSN(properties); +} + +bool WriteDSN(Connection::ConnPropertyMap properties) { + using arrow::flight::sql::odbc::Connection; + using arrow::flight::sql::odbc::FlightSqlConnection; + using arrow::flight::sql::odbc::config::Configuration; + using ODBC::ODBCConnection; + + Configuration config; + config.Set(FlightSqlConnection::DSN, std::string(kTestDsn)); + + for (const auto& [key, value] : properties) { + config.Set(key, value); + } + + std::string driver = config.Get(FlightSqlConnection::DRIVER); + std::wstring w_driver = arrow::util::UTF8ToWideString(driver).ValueOr(L""); + return RegisterDsn(config, w_driver.c_str()); +} +#endif + +std::wstring ConvertToWString(const std::vector& str_val, SQLSMALLINT str_len) { + std::wstring attr_str; + if (str_len == 0) { + attr_str = std::wstring(&str_val[0]); + } else { + EXPECT_GT(str_len, 0); + EXPECT_LE(str_len, static_cast(kOdbcBufferSize)); + attr_str = std::wstring(str_val.begin(), + str_val.begin() + str_len / ODBC::GetSqlWCharSize()); + } + return attr_str; +} + +void CheckStringColumnW(SQLHSTMT stmt, int col_id, const std::wstring& expected) { + SQLWCHAR buf[1024]; + SQLLEN buf_len = sizeof(buf) * ODBC::GetSqlWCharSize(); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(stmt, col_id, SQL_C_WCHAR, buf, buf_len, &buf_len)); + + EXPECT_GT(buf_len, 0); + + // returned buf_len is in bytes so convert to length in characters + size_t char_count = static_cast(buf_len) / ODBC::GetSqlWCharSize(); + std::wstring returned(buf, buf + char_count); + + EXPECT_EQ(expected, returned); +} + +void CheckNullColumnW(SQLHSTMT stmt, int col_id) { + SQLWCHAR buf[1024]; + SQLLEN buf_len = sizeof(buf); + + ASSERT_EQ(SQL_SUCCESS, SQLGetData(stmt, col_id, SQL_C_WCHAR, buf, buf_len, &buf_len)); + + EXPECT_EQ(SQL_NULL_DATA, buf_len); +} + +void CheckIntColumn(SQLHSTMT stmt, int col_id, const SQLINTEGER& expected) { + SQLINTEGER buf; + SQLLEN buf_len = sizeof(buf); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(stmt, col_id, SQL_C_LONG, &buf, sizeof(buf), &buf_len)); + + EXPECT_EQ(expected, buf); +} + +void CheckSmallIntColumn(SQLHSTMT stmt, int col_id, const SQLSMALLINT& expected) { + SQLSMALLINT buf; + SQLLEN buf_len = sizeof(buf); + + ASSERT_EQ(SQL_SUCCESS, + SQLGetData(stmt, col_id, SQL_C_SSHORT, &buf, sizeof(buf), &buf_len)); + + EXPECT_EQ(expected, buf); +} + +void ValidateFetch(SQLHSTMT stmt, SQLRETURN expected_return) { + ASSERT_EQ(expected_return, SQLFetch(stmt)); +} + +} // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h new file mode 100644 index 00000000000..e35e6c38f85 --- /dev/null +++ b/cpp/src/arrow/flight/sql/odbc/tests/odbc_test_suite.h @@ -0,0 +1,254 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/testing/gtest_util.h" +#include "arrow/util/io_util.h" +#include "arrow/util/utf8.h" + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/example/sqlite_server.h" +#include "arrow/flight/sql/odbc/odbc_impl/encoding_utils.h" +#include "arrow/flight/sql/odbc/odbc_impl/platform.h" + +#include +#include +#include + +#include + +#include "arrow/flight/sql/odbc/odbc_impl/odbc_connection.h" + +// For DSN registration +#include "arrow/flight/sql/odbc/odbc_impl/system_dsn.h" + +static constexpr std::string_view kTestConnectStr = "ARROW_FLIGHT_SQL_ODBC_CONN"; +static constexpr std::string_view kTestDsn = "Apache Arrow Flight SQL Test DSN"; + +namespace arrow::flight::sql::odbc { + +/// \brief Base test fixture for running tests against a remote server. +/// Each test file running remote server tests should define a +/// fixture inheriting from this base fixture. +/// The connection string for connecting to this server is defined +/// in the ARROW_FLIGHT_SQL_ODBC_CONN environment variable. +class FlightSQLODBCRemoteTestBase : public ::testing::Test { + public: + /// \brief Allocate environment and connection handles + void AllocEnvConnHandles(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Connect to Arrow Flight SQL server using connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN", allocate statement handle. + /// Connects using ODBC Ver 3 by default + void Connect(SQLINTEGER odbc_ver = SQL_OV_ODBC3); + /// \brief Connect to Arrow Flight SQL server using connection string + void ConnectWithString(std::string connection_str); + /// \brief Disconnect from server + void Disconnect(); + /// \brief Get connection string from environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual GetConnectionString(); + /// \brief Get invalid connection string based on connection string defined in + /// environment variable "ARROW_FLIGHT_SQL_ODBC_CONN" + std::string virtual GetInvalidConnectionString(); + /// \brief Return a SQL query that selects all data types + std::wstring virtual GetQueryAllDataTypes(); + + /** ODBC Environment. */ + SQLHENV env = 0; + + /** ODBC Connect. */ + SQLHDBC conn = 0; + + /** ODBC Statement. */ + SQLHSTMT stmt = 0; + + protected: + void SetUp() override; + + void TearDown() override; + + bool connected_ = false; +}; + +/// \brief Base test fixture for running ODBC V2 tests against a remote server. +/// Each test file running remote server ODBC V2 tests should define a +/// fixture inheriting from this base fixture. +class FlightSQLOdbcV2RemoteTestBase : public FlightSQLODBCRemoteTestBase { + protected: + void SetUp() override; +}; + +static constexpr std::string_view kAuthorizationHeader = "authorization"; +static constexpr std::string_view kBearerPrefix = "Bearer "; +static constexpr std::string_view kTestToken = "t0k3n"; + +std::string FindTokenInCallHeaders(const CallHeaders& incoming_headers); + +// A server middleware for validating incoming bearer header authentication. +class MockServerMiddleware : public ServerMiddleware { + public: + explicit MockServerMiddleware(const CallHeaders& incoming_headers, bool* is_valid) + : is_valid_(is_valid) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override; + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "MockServerMiddleware"; } + + private: + CallHeaders incoming_headers_; + bool* is_valid_; +}; + +// Factory for base64 header authentication testing. +class MockServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + MockServerMiddlewareFactory() : is_valid_(false) {} + + Status StartCall(const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware) override; + + private: + bool is_valid_; +}; + +/// \brief Base test fixture for running tests against a mock server. +/// Each test file running mock server tests should define a +/// fixture inheriting from this base fixture. +class FlightSQLODBCMockTestBase : public FlightSQLODBCRemoteTestBase { + // Sets up a mock server for each test case + public: + /// \brief Get connection string for mock server + std::string GetConnectionString() override; + /// \brief Get invalid connection string for mock server + std::string GetInvalidConnectionString() override; + /// \brief Return a SQL query that selects all data types + std::wstring GetQueryAllDataTypes() override; + + /// \brief Run a SQL query to create default table for table test cases + void CreateTestTables(); + + /// \brief run a SQL query to create a table with all data types + void CreateTableAllDataType(); + /// \brief run a SQL query to create a table with unicode name + void CreateUnicodeTable(); + + int port; + + protected: + void Initialize(); + + void SetUp() override; + + void TearDown() override; + + private: + std::shared_ptr server_; +}; + +/// \brief Base test fixture for running ODBC V2 tests against a mock server. +/// Each test file running mock server ODBC V2 tests should define a +/// fixture inheriting from this base fixture. +class FlightSQLOdbcV2MockTestBase : public FlightSQLODBCMockTestBase { + protected: + void SetUp() override; +}; + +/** ODBC read buffer size. */ +static constexpr int kOdbcBufferSize = 1024; + +/// Compare ConnPropertyMap, key value is case-insensitive +bool CompareConnPropertyMap(Connection::ConnPropertyMap map1, + Connection::ConnPropertyMap map2); + +/// Get error message from ODBC driver using SQLGetDiagRec +std::string GetOdbcErrorMessage(SQLSMALLINT handle_type, SQLHANDLE handle); + +static constexpr std::string_view kErrorState01004 = "01004"; +static constexpr std::string_view kErrorState01S07 = "01S07"; +static constexpr std::string_view kErrorState01S02 = "01S02"; +static constexpr std::string_view kErrorState07009 = "07009"; +static constexpr std::string_view kErrorState08003 = "08003"; +static constexpr std::string_view kErrorState22002 = "22002"; +static constexpr std::string_view kErrorState24000 = "24000"; +static constexpr std::string_view kErrorState28000 = "28000"; +static constexpr std::string_view kErrorStateHY000 = "HY000"; +static constexpr std::string_view kErrorStateHY004 = "HY004"; +static constexpr std::string_view kErrorStateHY009 = "HY009"; +static constexpr std::string_view kErrorStateHY010 = "HY010"; +static constexpr std::string_view kErrorStateHY017 = "HY017"; +static constexpr std::string_view kErrorStateHY024 = "HY024"; +static constexpr std::string_view kErrorStateHY090 = "HY090"; +static constexpr std::string_view kErrorStateHY091 = "HY091"; +static constexpr std::string_view kErrorStateHY092 = "HY092"; +static constexpr std::string_view kErrorStateHY106 = "HY106"; +static constexpr std::string_view kErrorStateHY114 = "HY114"; +static constexpr std::string_view kErrorStateHY118 = "HY118"; +static constexpr std::string_view kErrorStateHYC00 = "HYC00"; +static constexpr std::string_view kErrorStateS1004 = "S1004"; + +/// Verify ODBC Error State +void VerifyOdbcErrorState(SQLSMALLINT handle_type, SQLHANDLE handle, + std::string_view expected_state); + +/// \brief Write connection string into DSN +/// \param[in] connection_str the connection string. +/// \return true on success +bool WriteDSN(std::string connection_str); + +/// \brief Write properties map into DSN +/// \param[in] properties map. +/// \return true on success +bool WriteDSN(Connection::ConnPropertyMap properties); + +/// \brief Check wide char vector and convert into wstring +/// \param[in] str_val Vector of SQLWCHAR. +/// \param[in] str_len length of string, in bytes. +/// \return wstring +std::wstring ConvertToWString(const std::vector& str_val, SQLSMALLINT str_len); + +/// \brief Check wide string column. +/// \param[in] stmt Statement. +/// \param[in] col_id Column ID to check. +/// \param[in] expected Expected value. +void CheckStringColumnW(SQLHSTMT stmt, int col_id, const std::wstring& expected); + +/// \brief Check wide string column value is null. +/// \param[in] stmt Statement. +/// \param[in] col_id Column ID to check. +void CheckNullColumnW(SQLHSTMT stmt, int col_id); + +/// \brief Check int column. +/// \param[in] stmt Statement. +/// \param[in] col_id Column ID to check. +/// \param[in] expected Expected value. +void CheckIntColumn(SQLHSTMT stmt, int col_id, const SQLINTEGER& expected); + +/// \brief Check smallint column. +/// \param[in] stmt Statement. +/// \param[in] col_id Column ID to check. +/// \param[in] expected Expected value. +void CheckSmallIntColumn(SQLHSTMT stmt, int col_id, const SQLSMALLINT& expected); + +/// \brief Check sql return against expected. +/// \param[in] stmt Statement. +/// \param[in] expected Expected return. +void ValidateFetch(SQLHSTMT stmt, SQLRETURN expected); + +} // namespace arrow::flight::sql::odbc