diff --git a/cpp/src/arrow/flight/sql/odbc/entry_points.cc b/cpp/src/arrow/flight/sql/odbc/entry_points.cc index 63faa57e45b..38b4a1fc8ed 100644 --- a/cpp/src/arrow/flight/sql/odbc/entry_points.cc +++ b/cpp/src/arrow/flight/sql/odbc/entry_points.cc @@ -287,3 +287,13 @@ SQLRETURN SQL_API SQLSetStmtAttr(SQLHSTMT stmt, SQLINTEGER attribute, SQLPOINTER SQLINTEGER stringLength) { return arrow::SQLSetStmtAttr(stmt, attribute, valuePtr, stringLength); } + +SQLRETURN SQL_API SQLDescribeCol(SQLHSTMT statementHandle, SQLUSMALLINT columnNumber, + SQLWCHAR* columnName, SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr) { + return arrow::SQLDescribeCol(statementHandle, columnNumber, columnName, bufferLength, + nameLengthPtr, dataTypePtr, columnSizePtr, + decimalDigitsPtr, nullablePtr); +} diff --git a/cpp/src/arrow/flight/sql/odbc/odbc.def b/cpp/src/arrow/flight/sql/odbc/odbc.def index 6a7402ffa90..8ba5b3fff78 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc.def +++ b/cpp/src/arrow/flight/sql/odbc/odbc.def @@ -28,6 +28,7 @@ EXPORTS SQLColAttributeW SQLColumnsW SQLConnectW + SQLDescribeColW SQLDisconnect SQLDriverConnectW SQLExecDirectW diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc index 461e17fae5f..3ccde05ca5b 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc @@ -1367,4 +1367,118 @@ SQLRETURN SQLNativeSql(SQLHDBC connectionHandle, SQLWCHAR* inStatementText, }); } +SQLRETURN SQLDescribeCol(SQLHSTMT stmt, SQLUSMALLINT columnNumber, SQLWCHAR* columnName, + SQLSMALLINT bufferLength, SQLSMALLINT* nameLengthPtr, + SQLSMALLINT* dataTypePtr, SQLULEN* columnSizePtr, + SQLSMALLINT* decimalDigitsPtr, SQLSMALLINT* nullablePtr) { + LOG_DEBUG( + "SQLDescribeColW called with stmt: {}, columnNumber: {}, " + "columnName: {}, bufferLength: {}, nameLengthPtr: {}, dataTypePtr: {}, " + "columnSizePtr: {}, decimalDigitsPtr: {}, nullablePtr: {}", + stmt, columnNumber, fmt::ptr(columnName), bufferLength, fmt::ptr(nameLengthPtr), + fmt::ptr(dataTypePtr), fmt::ptr(columnSizePtr), fmt::ptr(decimalDigitsPtr), + fmt::ptr(nullablePtr)); + using ODBC::ODBCDescriptor; + using ODBC::ODBCStatement; + + return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() { + ODBCStatement* statement = reinterpret_cast(stmt); + ODBCDescriptor* ird = statement->GetIRD(); + SQLINTEGER outputLengthInt; + SQLSMALLINT sqlType; + + // Column SQL Type + ird->GetField(columnNumber, SQL_DESC_CONCISE_TYPE, &sqlType, sizeof(SQLSMALLINT), + nullptr); + if (dataTypePtr) { + *dataTypePtr = sqlType; + } + + // Column Name + if (columnName || nameLengthPtr) { + ird->GetField(columnNumber, SQL_DESC_NAME, columnName, bufferLength, + &outputLengthInt); + if (nameLengthPtr) { + *nameLengthPtr = static_cast(outputLengthInt); + } + } + + // Column Size + if (columnSizePtr) { + switch (sqlType) { + // All numeric types + case SQL_DECIMAL: + case SQL_NUMERIC: + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: { + ird->GetField(columnNumber, SQL_DESC_PRECISION, columnSizePtr, sizeof(SQLULEN), + nullptr); + break; + } + + default: { + ird->GetField(columnNumber, SQL_DESC_LENGTH, columnSizePtr, sizeof(SQLULEN), + nullptr); + } + } + } + + // Column Decimal Digits + if (decimalDigitsPtr) { + switch (sqlType) { + // All exact numeric types + case SQL_TINYINT: + case SQL_SMALLINT: + case SQL_INTEGER: + case SQL_BIGINT: + case SQL_DECIMAL: + case SQL_NUMERIC: { + ird->GetField(columnNumber, SQL_DESC_SCALE, decimalDigitsPtr, sizeof(SQLULEN), + nullptr); + break; + } + + // All datetime types (ODBC2) + case SQL_DATE: + case SQL_TIME: + case SQL_TIMESTAMP: + // All datetime types (ODBC3) + case SQL_TYPE_DATE: + case SQL_TYPE_TIME: + case SQL_TYPE_TIMESTAMP: + // All interval types with a seconds component + case SQL_INTERVAL_SECOND: + case SQL_INTERVAL_MINUTE_TO_SECOND: + case SQL_INTERVAL_HOUR_TO_SECOND: + case SQL_INTERVAL_DAY_TO_SECOND: { + ird->GetField(columnNumber, SQL_DESC_PRECISION, decimalDigitsPtr, + sizeof(SQLULEN), nullptr); + break; + } + + default: { + // All character and binary types + // SQL_BIT + // All approximate numeric types + // All interval types with no seconds component + *decimalDigitsPtr = static_cast(0); + } + } + } + + // Column Nullable + if (nullablePtr) { + ird->GetField(columnNumber, SQL_DESC_NULLABLE, nullablePtr, sizeof(SQLSMALLINT), + nullptr); + } + + return SQL_SUCCESS; + }); +} + } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.h b/cpp/src/arrow/flight/sql/odbc/odbc_api.h index 6aa2ec681af..94a7dc0ec3e 100644 --- a/cpp/src/arrow/flight/sql/odbc/odbc_api.h +++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.h @@ -96,4 +96,9 @@ SQLRETURN SQLGetTypeInfo(SQLHSTMT stmt, SQLSMALLINT dataType); SQLRETURN SQLNativeSql(SQLHDBC connectionHandle, SQLWCHAR* inStatementText, SQLINTEGER inStatementTextLength, SQLWCHAR* outStatementText, SQLINTEGER bufferLength, SQLINTEGER* outStatementTextLength); +SQLRETURN SQLDescribeCol(SQLHSTMT statementHandle, SQLUSMALLINT columnNumber, + SQLWCHAR* columnName, SQLSMALLINT bufferLength, + SQLSMALLINT* nameLengthPtr, SQLSMALLINT* dataTypePtr, + SQLULEN* columnSizePtr, SQLSMALLINT* decimalDigitsPtr, + SQLSMALLINT* nullablePtr); } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc index 08a86d4a840..04ca7ec9b96 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/columns_test.cc @@ -26,9 +26,6 @@ #include "gtest/gtest.h" -// TODO: add tests with SQLDescribeCol to check metadata of SQLColumns for ODBC 2 and -// ODBC 3. - namespace arrow::flight::sql::odbc { // Helper functions void checkSQLColumns( @@ -2321,4 +2318,558 @@ TYPED_TEST(FlightSQLODBCTestBase, TestSQLColAttributesUpdatable) { this->disconnect(); } + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColValidateInput) { + this->connect(); + this->CreateTestTables(); + + SQLSMALLINT columnCount = 0; + SQLSMALLINT expectedValue = 3; + SQLWCHAR sqlQuery[] = L"SELECT * FROM TestTable LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLUSMALLINT bookmarkColumn = 0; + SQLUSMALLINT validColumn = 1; + SQLUSMALLINT outOfRangeColumn = 4; + SQLUSMALLINT negativeColumn = -1; + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT dataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + // Invalid descriptor index - Bookmarks are not supported + ret = SQLDescribeCol(this->stmt, bookmarkColumn, columnName, bufCharLen, &nameLength, + &dataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + // Invalid descriptor index - index out of range + ret = SQLDescribeCol(this->stmt, outOfRangeColumn, columnName, bufCharLen, &nameLength, + &dataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + // Invalid descriptor index - index out of range + ret = SQLDescribeCol(this->stmt, negativeColumn, columnName, bufCharLen, &nameLength, + &dataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_ERROR); + VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, error_state_07009); + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColQueryAllDataTypesMetadata) { + // Mock server has a limitation where only SQL_WVARCHAR column type values are returned + // from SELECT AS queries + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLWCHAR* columnNames[] = { + (SQLWCHAR*)L"stiny_int_min", (SQLWCHAR*)L"stiny_int_max", + (SQLWCHAR*)L"utiny_int_min", (SQLWCHAR*)L"utiny_int_max", + (SQLWCHAR*)L"ssmall_int_min", (SQLWCHAR*)L"ssmall_int_max", + (SQLWCHAR*)L"usmall_int_min", (SQLWCHAR*)L"usmall_int_max", + (SQLWCHAR*)L"sinteger_min", (SQLWCHAR*)L"sinteger_max", + (SQLWCHAR*)L"uinteger_min", (SQLWCHAR*)L"uinteger_max", + (SQLWCHAR*)L"sbigint_min", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"ubigint_min", (SQLWCHAR*)L"ubigint_max", + (SQLWCHAR*)L"decimal_negative", (SQLWCHAR*)L"decimal_positive", + (SQLWCHAR*)L"float_min", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_min", (SQLWCHAR*)L"double_max", + (SQLWCHAR*)L"bit_false", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"c_char", (SQLWCHAR*)L"c_wchar", + (SQLWCHAR*)L"c_wvarchar", (SQLWCHAR*)L"c_varchar", + (SQLWCHAR*)L"date_min", (SQLWCHAR*)L"date_max", + (SQLWCHAR*)L"timestamp_min", (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = { + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR}; + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, 1024); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLDescribeColQueryAllDataTypesMetadata) { + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + std::wstring wsql = this->getQueryAllDataTypes(); + std::vector sql0(wsql.begin(), wsql.end()); + + SQLWCHAR* columnNames[] = { + (SQLWCHAR*)L"stiny_int_min", (SQLWCHAR*)L"stiny_int_max", + (SQLWCHAR*)L"utiny_int_min", (SQLWCHAR*)L"utiny_int_max", + (SQLWCHAR*)L"ssmall_int_min", (SQLWCHAR*)L"ssmall_int_max", + (SQLWCHAR*)L"usmall_int_min", (SQLWCHAR*)L"usmall_int_max", + (SQLWCHAR*)L"sinteger_min", (SQLWCHAR*)L"sinteger_max", + (SQLWCHAR*)L"uinteger_min", (SQLWCHAR*)L"uinteger_max", + (SQLWCHAR*)L"sbigint_min", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"ubigint_min", (SQLWCHAR*)L"ubigint_max", + (SQLWCHAR*)L"decimal_negative", (SQLWCHAR*)L"decimal_positive", + (SQLWCHAR*)L"float_min", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_min", (SQLWCHAR*)L"double_max", + (SQLWCHAR*)L"bit_false", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"c_char", (SQLWCHAR*)L"c_wchar", + (SQLWCHAR*)L"c_wvarchar", (SQLWCHAR*)L"c_varchar", + (SQLWCHAR*)L"date_min", (SQLWCHAR*)L"date_max", + (SQLWCHAR*)L"timestamp_min", (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = { + SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, + SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, SQL_INTEGER, + SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, SQL_BIGINT, + SQL_WVARCHAR, SQL_DECIMAL, SQL_DECIMAL, SQL_FLOAT, SQL_FLOAT, + SQL_DOUBLE, SQL_DOUBLE, SQL_BIT, SQL_BIT, SQL_WVARCHAR, + SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_TYPE_DATE, SQL_TYPE_DATE, + SQL_TYPE_TIMESTAMP, SQL_TYPE_TIMESTAMP}; + SQLULEN columnSizes[] = {4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 8, + 8, 8, 8, 8, 65536, 19, 19, 8, 8, 8, 8, + 1, 1, 65536, 65536, 65536, 65536, 10, 10, 23, 23}; + SQLULEN columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 10, 23, 23}; + + SQLRETURN ret = + SQLExecDirect(this->stmt, &sql0[0], static_cast(sql0.size())); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, columnDecimalDigits[i]); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLDescribeColODBCTestTableMetadata) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR sqlQuery[] = L"SELECT * from $scratch.ODBCTest LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"sinteger_max", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"decimal_positive", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_max", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"date_max", (SQLWCHAR*)L"time_max", + (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = {SQL_INTEGER, SQL_BIGINT, SQL_DECIMAL, + SQL_FLOAT, SQL_DOUBLE, SQL_BIT, + SQL_TYPE_DATE, SQL_TYPE_TIME, SQL_TYPE_TIMESTAMP}; + SQLULEN columnSizes[] = {4, 8, 19, 8, 8, 1, 10, 12, 23}; + SQLULEN columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 10, 12, 23}; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, columnDecimalDigits[i]); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCRemoteTestBase, SQLDescribeColODBCTestTableMetadataODBC2) { + // Test assumes there is a table $scratch.ODBCTest in remote server + this->connect(SQL_OV_ODBC2); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR sqlQuery[] = L"SELECT * from $scratch.ODBCTest LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"sinteger_max", (SQLWCHAR*)L"sbigint_max", + (SQLWCHAR*)L"decimal_positive", (SQLWCHAR*)L"float_max", + (SQLWCHAR*)L"double_max", (SQLWCHAR*)L"bit_true", + (SQLWCHAR*)L"date_max", (SQLWCHAR*)L"time_max", + (SQLWCHAR*)L"timestamp_max"}; + SQLSMALLINT columnDataTypes[] = {SQL_INTEGER, SQL_BIGINT, SQL_DECIMAL, + SQL_FLOAT, SQL_DOUBLE, SQL_BIT, + SQL_DATE, SQL_TIME, SQL_TIMESTAMP}; + SQLULEN columnSizes[] = {4, 8, 19, 8, 8, 1, 10, 12, 23}; + SQLULEN columnDecimalDigits[] = {0, 0, 0, 0, 0, 0, 10, 12, 23}; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, columnDecimalDigits[i]); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColAllTypesTableMetadata) { + this->connect(); + this->CreateTableAllDataType(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR sqlQuery[] = L"SELECT * from AllTypesTable LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"bigint_col", (SQLWCHAR*)L"char_col", + (SQLWCHAR*)L"varbinary_col", (SQLWCHAR*)L"double_col"}; + SQLSMALLINT columnDataTypes[] = {SQL_BIGINT, SQL_WVARCHAR, SQL_BINARY, SQL_DOUBLE}; + SQLULEN columnSizes[] = {8, 0, 0, 8}; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TEST_F(FlightSQLODBCMockTestBase, SQLDescribeColUnicodeTableMetadata) { + this->connect(); + this->CreateUnicodeTable(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 1; + + SQLWCHAR sqlQuery[] = L"SELECT * from 数据 LIMIT 1;"; + SQLINTEGER queryLength = static_cast(wcslen(sqlQuery)); + + SQLWCHAR expectedColumnName[] = L"资料"; + SQLSMALLINT expectedColumnDataType = SQL_WVARCHAR; + SQLULEN expectedColumnSize = 0; + + SQLRETURN ret = SQLExecDirect(this->stmt, sqlQuery, queryLength); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLFetch(this->stmt); + + EXPECT_EQ(ret, SQL_SUCCESS); + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, expectedColumnName); + EXPECT_EQ(columnDataType, expectedColumnDataType); + EXPECT_EQ(columnSize, expectedColumnSize); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColumnsGetMetadataBySQLDescribeCol) { + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_CAT", (SQLWCHAR*)L"TABLE_SCHEM", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"COLUMN_NAME", + (SQLWCHAR*)L"DATA_TYPE"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_SMALLINT}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 2}; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLColumnsGetMetadataBySQLDescribeColODBC2) { + this->connect(SQL_OV_ODBC2); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_QUALIFIER", (SQLWCHAR*)L"TABLE_OWNER", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"COLUMN_NAME", + (SQLWCHAR*)L"DATA_TYPE"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_SMALLINT}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 2}; + + SQLRETURN ret = SQLColumns(this->stmt, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} } // namespace arrow::flight::sql::odbc diff --git a/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc b/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc index 1dc0c90245e..68405c51583 100644 --- a/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc +++ b/cpp/src/arrow/flight/sql/odbc/tests/tables_test.cc @@ -28,9 +28,6 @@ namespace arrow::flight::sql::odbc { -// TODO: Add tests with SQLDescribeCol to check metadata of SQLColumns for ODBC 2 and -// ODBC 3. - // Helper Functions std::wstring GetStringColumnW(SQLHSTMT stmt, int colId) { @@ -578,4 +575,109 @@ TEST_F(FlightSQLODBCRemoteTestBase, SQLTablesGetSupportedTableTypes) { this->disconnect(); } +TYPED_TEST(FlightSQLODBCTestBase, SQLTablesGetMetadataBySQLDescribeCol) { + this->connect(); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_CAT", (SQLWCHAR*)L"TABLE_SCHEM", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"TABLE_TYPE", + (SQLWCHAR*)L"REMARKS"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 1024}; + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, + SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} + +TYPED_TEST(FlightSQLODBCTestBase, SQLTablesGetMetadataBySQLDescribeColODBC2) { + this->connect(SQL_OV_ODBC2); + + SQLWCHAR columnName[1024]; + constexpr SQLINTEGER bufCharLen = sizeof(columnName) / ODBC::GetSqlWCharSize(); + SQLSMALLINT nameLength = 0; + SQLSMALLINT columnDataType = 0; + SQLULEN columnSize = 0; + SQLSMALLINT decimalDigits = 0; + SQLSMALLINT nullable = 0; + size_t columnIndex = 0; + + SQLWCHAR* columnNames[] = {(SQLWCHAR*)L"TABLE_QUALIFIER", (SQLWCHAR*)L"TABLE_OWNER", + (SQLWCHAR*)L"TABLE_NAME", (SQLWCHAR*)L"TABLE_TYPE", + (SQLWCHAR*)L"REMARKS"}; + SQLSMALLINT columnDataTypes[] = {SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, SQL_WVARCHAR, + SQL_WVARCHAR}; + SQLULEN columnSizes[] = {1024, 1024, 1024, 1024, 1024}; + + SQLRETURN ret = SQLTables(this->stmt, nullptr, SQL_NTS, nullptr, SQL_NTS, nullptr, + SQL_NTS, nullptr, SQL_NTS); + + EXPECT_EQ(ret, SQL_SUCCESS); + + for (size_t i = 0; i < sizeof(columnNames) / sizeof(*columnNames); ++i) { + columnIndex = i + 1; + + ret = SQLDescribeCol(this->stmt, columnIndex, columnName, bufCharLen, &nameLength, + &columnDataType, &columnSize, &decimalDigits, &nullable); + + EXPECT_EQ(ret, SQL_SUCCESS); + + EXPECT_GT(nameLength, 0); + + // Returned nameLength is in bytes so convert to length in characters + size_t charCount = static_cast(nameLength) / ODBC::GetSqlWCharSize(); + std::wstring returned(columnName, columnName + charCount); + EXPECT_EQ(returned, columnNames[i]); + EXPECT_EQ(columnDataType, columnDataTypes[i]); + EXPECT_EQ(columnSize, columnSizes[i]); + EXPECT_EQ(decimalDigits, 0); + EXPECT_EQ(nullable, SQL_NULLABLE); + + nameLength = 0; + columnDataType = 0; + columnSize = 0; + decimalDigits = 0; + nullable = 0; + } + + this->disconnect(); +} } // namespace arrow::flight::sql::odbc