Skip to content
16 changes: 8 additions & 8 deletions cpp/src/arrow/flight/sql/odbc/entry_points.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ SQLRETURN SQL_API SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle,
diagInfoPtr, bufferLength, stringLengthPtr);
}

SQLRETURN SQL_API SQLGetDiagRecW(SQLSMALLINT handleType, SQLHANDLE handle,
SQLSMALLINT recNumber, SQLWCHAR* sqlState,
SQLINTEGER* nativeErrorPtr, SQLWCHAR* messageText,
SQLSMALLINT bufferLength, SQLSMALLINT* textLengthPtr) {
return arrow::SQLGetDiagRecW(handleType, handle, recNumber, sqlState, nativeErrorPtr,
messageText, bufferLength, textLengthPtr);
}

SQLRETURN SQL_API SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr,
SQLINTEGER bufferLen, SQLINTEGER* strLenPtr) {
return arrow::SQLGetEnvAttr(env, attr, valuePtr, bufferLen, strLenPtr);
Expand All @@ -81,14 +89,6 @@ SQLRETURN SQL_API SQLGetInfoW(SQLHDBC conn, SQLUSMALLINT infoType,
return arrow::SQLGetInfoW(conn, infoType, infoValuePtr, bufLen, length);
}

SQLRETURN SQL_API SQLGetDiagRecW(SQLSMALLINT type, SQLHANDLE handle, SQLSMALLINT recNum,
SQLWCHAR* sqlState, SQLINTEGER* nativeError,
SQLWCHAR* msgBuffer, SQLSMALLINT msgBufferLen,
SQLSMALLINT* msgLen) {
// TODO implement SQLGetDiagRecW
return SQL_ERROR;
}

SQLRETURN SQL_API SQLDriverConnectW(SQLHDBC conn, SQLHWND windowHandle,
SQLWCHAR* inConnectionString,
SQLSMALLINT inConnectionStringLen,
Expand Down
84 changes: 82 additions & 2 deletions cpp/src/arrow/flight/sql/odbc/odbc_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,13 +230,93 @@ SQLRETURN SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle,
stringLengthPtr, *diagnostics);
}

default:
return SQL_ERROR;
default: {
// TODO Return correct dummy values
return SQL_SUCCESS;
}
}

return SQL_ERROR;
}

SQLRETURN SQLGetDiagRecW(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber,
SQLWCHAR* sqlState, SQLINTEGER* nativeErrorPtr,
SQLWCHAR* messageText, SQLSMALLINT bufferLength,
SQLSMALLINT* textLengthPtr) {
using driver::odbcabstraction::Diagnostics;
using ODBC::ConvertToSqlWChar;
using ODBC::GetStringAttribute;
using ODBC::ODBCConnection;
using ODBC::ODBCEnvironment;

if (!handle) {
return SQL_INVALID_HANDLE;
}

// Record number must be greater or equal to 1
if (recNumber < 1 || bufferLength < 0) {
return SQL_ERROR;
}

// Set character type to be Unicode by default
const bool isUnicode = true;
Diagnostics* diagnostics = nullptr;

switch (handleType) {
case SQL_HANDLE_ENV: {
auto* environment = ODBCEnvironment::of(handle);
diagnostics = &environment->GetDiagnostics();
break;
}

case SQL_HANDLE_DBC: {
auto* connection = ODBCConnection::of(handle);
diagnostics = &connection->GetDiagnostics();
break;
}

case SQL_HANDLE_DESC: {
return SQL_ERROR;
}

case SQL_HANDLE_STMT: {
return SQL_ERROR;
}

default:
return SQL_INVALID_HANDLE;
}

if (!diagnostics) {
return SQL_ERROR;
}

// Convert from ODBC 1 based record number to internal diagnostics 0 indexed storage
const size_t recordIndex = static_cast<size_t>(recNumber - 1);
if (!diagnostics->HasRecord(recordIndex)) {
return SQL_NO_DATA;
}

if (sqlState) {
// The length of the sql state is always 5 characters plus null
SQLSMALLINT size = 6;
const std::string& state = diagnostics->GetSQLState(recordIndex);
GetStringAttribute(isUnicode, state, false, sqlState, size, &size, *diagnostics);
}

if (nativeErrorPtr) {
*nativeErrorPtr = diagnostics->GetNativeError(recordIndex);
}

if (messageText || textLengthPtr) {
const std::string& message = diagnostics->GetMessageText(recordIndex);
return GetStringAttribute(isUnicode, message, false, messageText, bufferLength,
textLengthPtr, *diagnostics);
}

return SQL_SUCCESS;
}

SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr,
SQLINTEGER bufferLen, SQLINTEGER* strLenPtr) {
using driver::odbcabstraction::DriverException;
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/flight/sql/odbc/odbc_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ SQLRETURN SQLGetDiagFieldW(SQLSMALLINT handleType, SQLHANDLE handle,
SQLSMALLINT recNumber, SQLSMALLINT diagIdentifier,
SQLPOINTER diagInfoPtr, SQLSMALLINT bufferLength,
SQLSMALLINT* stringLengthPtr);
SQLRETURN SQLGetDiagRecW(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT recNumber,
SQLWCHAR* sqlState, SQLINTEGER* nativeErrorPtr,
SQLWCHAR* messageText, SQLSMALLINT bufferLength,
SQLSMALLINT* textLengthPtr);
SQLRETURN SQLGetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr,
SQLINTEGER bufferLen, SQLINTEGER* strLenPtr);
SQLRETURN SQLSetEnvAttr(SQLHENV env, SQLINTEGER attr, SQLPOINTER valuePtr,
Expand Down
82 changes: 78 additions & 4 deletions cpp/src/arrow/flight/sql/odbc/tests/connection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,7 @@ TEST(SQLDriverConnect, TestSQLDriverConnectInvalidUid) {

EXPECT_TRUE(ret == SQL_ERROR);

// TODO uncomment this check when SQLGetDiagRec is implemented
// VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000"));
VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000"));

// TODO: Check that outstr remains empty after SqlWcharToString
// is fixed to handle empty `outstr`
Expand Down Expand Up @@ -621,8 +620,7 @@ TEST(SQLConnect, TestSQLConnectInvalidUid) {
// so connection still fails despite passing valid uid in SQLConnect call
EXPECT_TRUE(ret == SQL_ERROR);

// TODO uncomment this check when SQLGetDiagRec is implemented
// VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000"));
VerifyOdbcErrorState(SQL_HANDLE_DBC, conn, std::string("28000"));

// Remove DSN
EXPECT_TRUE(UnregisterDsn(dsn));
Expand Down Expand Up @@ -747,6 +745,82 @@ TEST(SQLDisconnect, TestSQLDisconnectWithoutConnection) {

EXPECT_TRUE(ret == SQL_SUCCESS);
}

TEST(SQLGetDiagRec, TestSQLGetDiagRecForConnectFailure) {
// ODBC Environment
SQLHENV env;
SQLHDBC conn;

// Allocate an environment handle
SQLRETURN ret = SQLAllocEnv(&env);

EXPECT_TRUE(ret == SQL_SUCCESS);

ret = SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0);

EXPECT_TRUE(ret == SQL_SUCCESS);

// Allocate a connection using alloc handle
ret = SQLAllocHandle(SQL_HANDLE_DBC, env, &conn);

EXPECT_TRUE(ret == SQL_SUCCESS);

// Connect string
ASSERT_OK_AND_ASSIGN(std::string connect_str,
arrow::internal::GetEnvVar(TEST_CONNECT_STR));
// Append invalid uid to connection string
connect_str += std::string("uid=non_existent_id;");

ASSERT_OK_AND_ASSIGN(std::wstring wconnect_str,
arrow::util::UTF8ToWideString(connect_str));
std::vector<SQLWCHAR> connect_str0(wconnect_str.begin(), wconnect_str.end());

SQLWCHAR outstr[ODBC_BUFFER_SIZE];
SQLSMALLINT outstrlen;

// Connecting to ODBC server.
ret = SQLDriverConnect(conn, NULL, &connect_str0[0],
static_cast<SQLSMALLINT>(connect_str0.size()), outstr,
ODBC_BUFFER_SIZE, &outstrlen, SQL_DRIVER_NOPROMPT);

EXPECT_TRUE(ret == SQL_ERROR);

if (ret != SQL_SUCCESS) {
std::cerr << GetOdbcErrorMessage(SQL_HANDLE_DBC, conn) << std::endl;
}
Comment on lines +788 to +790

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the code for debugging, it can be removed


SQLWCHAR sql_state[6];
SQLINTEGER native_error;
SQLWCHAR message[ODBC_BUFFER_SIZE];
SQLSMALLINT message_length;

ret = SQLGetDiagRec(SQL_HANDLE_DBC, conn, 1, sql_state, &native_error, message,
ODBC_BUFFER_SIZE, &message_length);

EXPECT_TRUE(ret == SQL_SUCCESS);

EXPECT_TRUE(message_length > 200);

EXPECT_TRUE(native_error == 200);

// 28000
EXPECT_TRUE(sql_state[0] == '2');
EXPECT_TRUE(sql_state[1] == '8');
EXPECT_TRUE(sql_state[2] == '0');
EXPECT_TRUE(sql_state[3] == '0');
EXPECT_TRUE(sql_state[4] == '0');

// Free connection handle
ret = SQLFreeHandle(SQL_HANDLE_DBC, conn);

EXPECT_TRUE(ret == SQL_SUCCESS);

// Free environment handle
ret = SQLFreeHandle(SQL_HANDLE_ENV, env);

EXPECT_TRUE(ret == SQL_SUCCESS);
}

} // namespace integration_tests
} // namespace odbc
} // namespace flight
Expand Down