From eba12625e87a8238646ddce4e2a98c9f8ba69eab Mon Sep 17 00:00:00 2001 From: Vlad Lyutenko Date: Fri, 23 Sep 2022 17:39:04 +0200 Subject: [PATCH] Call sp_rename as prepared statement Prevents sql injection --- .../plugin/sqlserver/SqlServerClient.java | 49 ++++++++++--------- .../sqlserver/TestSqlServerConnectorTest.java | 35 ++++++++++++- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java index b5885879c358..20c6d7479c8f 100644 --- a/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java +++ b/plugin/trino-sqlserver/src/main/java/io/trino/plugin/sqlserver/SqlServerClient.java @@ -302,29 +302,40 @@ protected void verifyColumnName(DatabaseMetaData databaseMetadata, String column protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException { + // sp_rename treats first argument as SQL object name, so it needs to be properly quoted and escaped. + // The second argument is treated literally. if (!remoteSchemaName.equals(newRemoteSchemaName)) { throw new TrinoException(NOT_SUPPORTED, "This connector does not support renaming tables across schemas"); } - - execute(connection, format( - "sp_rename %s, %s", - singleQuote(catalogName, remoteSchemaName, remoteTableName), - singleQuote(newRemoteTableName))); + String fullTableFromName = DOT_JOINER.join( + quoted(catalogName), + quoted(remoteSchemaName), + quoted(remoteTableName)); + + try (CallableStatement renameTable = connection.prepareCall("exec sp_rename ?, ?")) { + renameTable.setString(1, fullTableFromName); + renameTable.setString(2, newRemoteTableName); + renameTable.execute(); + } } @Override protected void renameColumn(ConnectorSession session, Connection connection, RemoteTableName remoteTableName, String remoteColumnName, String newRemoteColumnName) throws SQLException { - execute(connection, format( - "sp_rename %s, %s, 'COLUMN'", - singleQuote(remoteTableName.getCatalogName().orElse(null), remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName(), "[" + escape(remoteColumnName) + "]"), - "[" + newRemoteColumnName + "]")); - } - - private static String escape(String name) - { - return name.replace("'", "''"); + // sp_rename treats first argument as SQL object name, so it needs to be properly quoted and escaped. + // The second arqgument is treated literally. + String columnFrom = DOT_JOINER.join( + quoted(remoteTableName.getCatalogName().orElseThrow()), + quoted(remoteTableName.getSchemaName().orElseThrow()), + quoted(remoteTableName.getTableName()), + quoted(remoteColumnName)); + + try (CallableStatement renameColumn = connection.prepareCall("exec sp_rename ?, ?, 'COLUMN'")) { + renameColumn.setString(1, columnFrom); + renameColumn.setString(2, newRemoteColumnName); + renameColumn.execute(); + } } @Override @@ -974,16 +985,6 @@ public Connection getConnection(ConnectorSession session, JdbcOutputTableHandle return connection; } - private static String singleQuote(String... objects) - { - return singleQuote(DOT_JOINER.join(objects)); - } - - private static String singleQuote(String literal) - { - return "\'" + literal + "\'"; - } - public static ColumnMapping varbinaryColumnMapping() { return ColumnMapping.sliceMapping( diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java index c73ae51aeab9..e7328891456d 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerConnectorTest.java @@ -155,7 +155,7 @@ public void testInsertWriteBulkinessWithTimestamps(String timestampType) } } - // TODO move test to BaseConnectorTest + // TODO move test to BaseConnectorTest https://github.com/trinodb/trino/issues/14517 @Test(dataProvider = "testTableNameDataProvider") public void testCreateAndDropTableWithSpecialCharacterName(String tableName) { @@ -170,6 +170,39 @@ public void testCreateAndDropTableWithSpecialCharacterName(String tableName) assertFalse(getQueryRunner().tableExists(getSession(), tableName)); } + // TODO remove this test after https://github.com/trinodb/trino/issues/14517 + @Test(dataProvider = "testTableNameDataProvider") + public void testRenameColumnNameAdditionalTests(String columnName) + { + String nameInSql = "\"" + columnName.replace("\"", "\"\"") + "\""; + String tableName = "tcn_" + nameInSql.replaceAll("[^a-z0-9]", "") + randomTableSuffix(); + // Use complex identifier to test a source column name when renaming columns + String sourceColumnName = "a;b$c"; + + assertUpdate("CREATE TABLE " + tableName + "(\"" + sourceColumnName + "\" varchar(50))"); + assertTableColumnNames(tableName, sourceColumnName); + + assertUpdate("ALTER TABLE " + tableName + " RENAME COLUMN \"" + sourceColumnName + "\" TO " + nameInSql); + assertTableColumnNames(tableName, columnName.toLowerCase(ENGLISH)); + + assertUpdate("DROP TABLE " + tableName); + } + + // TODO move this test to BaseConnectorTest https://github.com/trinodb/trino/issues/14517 + @Test(dataProvider = "testTableNameDataProvider") + public void testRenameFromToTableWithSpecialCharacterName(String tableName) + { + String tableNameInSql = "\"" + tableName.replace("\"", "\"\"") + "\""; + String sourceTableName = "test_rename_source_" + randomTableSuffix(); + assertUpdate("CREATE TABLE " + sourceTableName + " AS SELECT 123 x", 1); + + assertUpdate("ALTER TABLE " + sourceTableName + " RENAME TO " + tableNameInSql); + assertQuery("SELECT x FROM " + tableNameInSql, "VALUES 123"); + // test rename back is working properly + assertUpdate("ALTER TABLE " + tableNameInSql + " RENAME TO " + sourceTableName); + assertUpdate("DROP TABLE " + sourceTableName); + } + private int getTableOperationsCount(String operation, String table) throws SQLException {