diff --git a/cpp/src/arrow/flight/sql/odbc/entry_points.cc b/cpp/src/arrow/flight/sql/odbc/entry_points.cc index 0fc55720938..46536b84e9b 100644 --- a/cpp/src/arrow/flight/sql/odbc/entry_points.cc +++ b/cpp/src/arrow/flight/sql/odbc/entry_points.cc @@ -59,6 +59,14 @@ SQLRETURN SQL_API SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle, diagInfoPtr, bufferLength, stringLengthPtr); } +SQLRETURN SQL_API SQLGetDiagRecW(SQLSMALLINT handleType, SQLHANDLE handle, + SQLSMALLINT recNumber, SQLWCHAR* sqlState, + SQLINTEGER* nativeErrorPtr, SQLWCHAR* messageText, + SQLSMALLINT bufferLength, SQLSMALLINT* textLengthPtr) { + return arrow::SQLGetDiagRecW(handleType, handle, recNumber, sqlState, nativeErrorPtr, + messageText, bufferLength, textLengthPtr); +} + SQLRETURN SQL_API SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, SQLINTEGER bufferLen, SQLINTEGER* strLenPtr) { return arrow::SQLGetEnvAttr(env, attr, valuePtr, bufferLen, strLenPtr); @@ -81,14 +89,6 @@ SQLRETURN SQL_API SQLGetInfoW(SQLHDBC conn, SQLUSMALLINT infoType, return arrow::SQLGetInfoW(conn, infoType, infoValuePtr, bufLen, length); } -SQLRETURN SQL_API SQLGetDiagRecW(SQLSMALLINT type, SQLHANDLE handle, SQLSMALLINT recNum, - SQLWCHAR* sqlState, SQLINTEGER* nativeError, - SQLWCHAR* msgBuffer, SQLSMALLINT msgBufferLen, - SQLSMALLINT* msgLen) { - // TODO implement SQLGetDiagRecW - return SQL_ERROR; -} - SQLRETURN SQL_API SQLDriverConnectW(SQLHDBC conn, SQLHWND windowHandle, SQLWCHAR* inConnectionString, SQLSMALLINT inConnectionStringLen, diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 9d2df40c05a..02b597d14e9 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -230,13 +230,93 @@ SQLRETURN SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle, stringLengthPtr, *diagnostics); } - default: - return SQL_ERROR; + default: { + // TODO Return correct dummy values + return SQL_SUCCESS; + } } return SQL_ERROR; } +SQLRETURN SQLGetDiagRecW(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, + SQLWCHAR* sqlState, SQLINTEGER* nativeErrorPtr, + SQLWCHAR* messageText, SQLSMALLINT bufferLength, + SQLSMALLINT* textLengthPtr) { + using driver::odbcabstraction::Diagnostics; + using ODBC::ConvertToSqlWChar; + using ODBC::GetStringAttribute; + using ODBC::ODBCConnection; + using ODBC::ODBCEnvironment; + + if (!handle) { + return SQL_INVALID_HANDLE; + } + + // Record number must be greater or equal to 1 + if (recNumber < 1 || bufferLength < 0) { + return SQL_ERROR; + } + + // Set character type to be Unicode by default + const bool isUnicode = true; + Diagnostics* diagnostics = nullptr; + + switch (handleType) { + case SQL_HANDLE_ENV: { + auto* environment = ODBCEnvironment::of(handle); + diagnostics = &environment->GetDiagnostics(); + break; + } + + case SQL_HANDLE_DBC: { + auto* connection = ODBCConnection::of(handle); + diagnostics = &connection->GetDiagnostics(); + break; + } + + case SQL_HANDLE_DESC: { + return SQL_ERROR; + } + + case SQL_HANDLE_STMT: { + return SQL_ERROR; + } + + default: + return SQL_INVALID_HANDLE; + } + + if (!diagnostics) { + return SQL_ERROR; + } + + // Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage + const size_t recordIndex = static_cast(recNumber - 1); + if (!diagnostics->HasRecord(recordIndex)) { + return SQL_NO_DATA; + } + + if (sqlState) { + // The length of the sql state is always 5 characters plus null + SQLSMALLINT size = 6; + const std::string& state = diagnostics->GetSQLState(recordIndex); + GetStringAttribute(isUnicode, state, false, sqlState, size, &size, *diagnostics); + } + + if (nativeErrorPtr) { + *nativeErrorPtr = diagnostics->GetNativeError(recordIndex); + } + + if (messageText || textLengthPtr) { + const std::string& message = diagnostics->GetMessageText(recordIndex); + return GetStringAttribute(isUnicode, message, false, messageText, bufferLength, + textLengthPtr, *diagnostics); + } + + return SQL_SUCCESS; +} + SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, SQLINTEGER bufferLen, SQLINTEGER* strLenPtr) { using driver::odbcabstraction::DriverException; diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.h b/cpp/src/arrow/flight/sql/odbc/odbc_api.h index 14bbddaa3ce..d5c392bcf59 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.h @@ -35,6 +35,10 @@ SQLRETURN SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier, SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength, SQLSMALLINT* stringLengthPtr); +SQLRETURN SQLGetDiagRecW(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber, + SQLWCHAR* sqlState, SQLINTEGER* nativeErrorPtr, + SQLWCHAR* messageText, SQLSMALLINT bufferLength, + SQLSMALLINT* textLengthPtr); SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, SQLINTEGER bufferLen, SQLINTEGER* strLenPtr); SQLRETURN SQLSetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr, diff --git a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc index c6a39ccc361..039a5fd074e 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc @@ -400,8 +400,7 @@ TEST(SQLDriverConnect, TestSQLDriverConnectInvalidUid) { EXPECT_TRUE(ret == SQL_ERROR); - // TODO uncomment this check when SQLGetDiagRec is implemented - // VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000")); + VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000")); // TODO: Check that outstr remains empty after SqlWcharToString // is fixed to handle empty `outstr` @@ -621,8 +620,7 @@ TEST(SQLConnect, TestSQLConnectInvalidUid) { // so connection still fails despite passing valid uid in SQLConnect call EXPECT_TRUE(ret == SQL_ERROR); - // TODO uncomment this check when SQLGetDiagRec is implemented - // VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000")); + VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000")); // Remove DSN EXPECT_TRUE(UnregisterDsn(dsn)); @@ -747,6 +745,82 @@ TEST(SQLDisconnect, TestSQLDisconnectWithoutConnection) { EXPECT_TRUE(ret == SQL_SUCCESS); } + +TEST(SQLGetDiagRec, TestSQLGetDiagRecForConnectFailure) { + // ODBC Environment + SQLHENV env; + SQLHDBC conn; + + // Allocate an environment handle + SQLRETURN ret = SQLAllocEnv(&env); + + EXPECT_TRUE(ret == SQL_SUCCESS); + + ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + EXPECT_TRUE(ret == SQL_SUCCESS); + + // Allocate a connection using alloc handle + ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn); + + EXPECT_TRUE(ret == SQL_SUCCESS); + + // Connect string + ASSERT_OK_AND_ASSIGN(std::string connect_str, + arrow::internal::GetEnvVar(TEST_CONNECT_STR)); + // Append invalid uid to connection string + connect_str += std::string("uid=non_existent_id;"); + + ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str, + arrow::util::UTF8ToWideString(connect_str)); + std::vector connect_str0(wconnect_str.begin(), wconnect_str.end()); + + SQLWCHAR outstr[ODBC_BUFFER_SIZE]; + SQLSMALLINT outstrlen; + + // Connecting to ODBC server. + ret = SQLDriverConnect(conn, NULL, &connect_str0[0], + static_cast(connect_str0.size()), outstr, + ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT); + + EXPECT_TRUE(ret == SQL_ERROR); + + if (ret != SQL_SUCCESS) { + std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl; + } + + SQLWCHAR sql_state[6]; + SQLINTEGER native_error; + SQLWCHAR message[ODBC_BUFFER_SIZE]; + SQLSMALLINT message_length; + + ret = SQLGetDiagRec(SQL_HANDLE_DBC, conn, 1, sql_state, &native_error, message, + ODBC_BUFFER_SIZE, &message_length); + + EXPECT_TRUE(ret == SQL_SUCCESS); + + EXPECT_TRUE(message_length > 200); + + EXPECT_TRUE(native_error == 200); + + // 28000 + EXPECT_TRUE(sql_state[0] == '2'); + EXPECT_TRUE(sql_state[1] == '8'); + EXPECT_TRUE(sql_state[2] == '0'); + EXPECT_TRUE(sql_state[3] == '0'); + EXPECT_TRUE(sql_state[4] == '0'); + + // Free connection handle + ret = SQLFreeHandle(SQL_HANDLE_DBC, conn); + + EXPECT_TRUE(ret == SQL_SUCCESS); + + // Free environment handle + ret = SQLFreeHandle(SQL_HANDLE_ENV, env); + + EXPECT_TRUE(ret == SQL_SUCCESS); +} + } // namespace integration_tests } // namespace odbc } // namespace flight