From 2dde452e4a03aa8a85628d96c0bf8aff0c5e9601 Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Mon, 19 Oct 2015 13:00:32 -0700 Subject: [PATCH 1/6] SPARK-10857: Vet table names which are inserted into SQL strings to verify that they are legal SQL identifiers in the underlying database. --- .../apache/spark/sql/SqlIdentifierUtil.java | 322 ++++++++++++++++++ .../apache/spark/sql/DataFrameWriter.scala | 6 +- .../datasources/jdbc/JdbcUtils.scala | 19 +- .../apache/spark/sql/jdbc/JdbcDialects.scala | 71 +++- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 56 +++ .../spark/sql/jdbc/JDBCWriteSuite.scala | 25 ++ 6 files changed, 485 insertions(+), 14 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java new file mode 100644 index 000000000000..ff05c53b6693 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql; + +import java.io.IOException; +import java.io.StringReader; +import java.util.Locale; +import java.util.Vector; +import org.apache.spark.SparkException; + +/** + * Methods for handling SQL identifiers. These methods were cribbed + * from org.apache.derby.iapi.util.IdUtil and + * org.apache.derby.iapi.util.StringUtil. + */ +public class SqlIdentifierUtil { + + /** + * Quote a string so that it can be used as an identifier or a string + * literal in SQL statements. Identifiers are surrounded by double quotes + * and string literals are surrounded by single quotes. If the string + * contains quote characters, they are escaped. + * + * @param source the string to quote + * @param quote the character to quote the string with (' or ") + * @return a string quoted with the specified quote character + */ + public static String quoteString(String source, char quote) { + // Normally, the quoted string is two characters longer than the source + // string (because of start quote and end quote). + StringBuffer quoted = new StringBuffer(source.length() + 2); + quoted.append(quote); + for (int i = 0; i < source.length(); i++) { + char c = source.charAt(i); + // if the character is a quote, escape it with an extra quote + if (c == quote) quoted.append(quote); + quoted.append(c); + } + quoted.append(quote); + return quoted.toString(); + } + + /** + * Parse a multi-part (dot separated) SQL identifier from the + * String provided. Raise an excepion + * if the string does not contain valid SQL indentifiers. + * The returned String array contains the normalized form of the + * identifiers. + * + * @param s The string to be parsed + * @param quoteCharacter The character which frames a delimited id (e.g., " or `) + * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. + * @return An array of strings made by breaking the input string at its dots, '.'. + * @throws SparkException Invalid SQL identifier. + */ + public static String[] parseMultiPartSQLIdentifier( + String s, + char quoteCharacter, + boolean upperCaseIdentifiers) + throws SparkException { + StringReader r = new StringReader(s); + String[] qName = parseMultiPartSQLIdentifier( + s, + r, + quoteCharacter, + upperCaseIdentifiers); + verifyEmpty(s, r); + return qName; + } + + /** + * Parse a multi-part (dot separated) SQL identifier from the + * String provided. Raise an excepion + * if the string does not contain valid SQL indentifiers. + * The returned String array contains the normalized form of the + * identifiers. + * + * @param orig The full text being parsed + * @param r The multi-part identifier to be parsed + * @param quoteCharacter The character which frames a delimited id (e.g., " or `) + * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. + * @return An array of strings made by breaking the input string at its dots, '.'. + * @throws SparkException Invalid SQL identifier. + */ + private static String[] parseMultiPartSQLIdentifier( + String orig, + StringReader r, + char quoteCharacter, + boolean upperCaseIdentifiers) + throws SparkException { + Vector v = new Vector(); + while (true) { + String thisId = parseId(orig, r, quoteCharacter, upperCaseIdentifiers); + v.add(thisId); + int dot; + + try { + r.mark(0); + dot = r.read(); + if (dot != '.') { + if (dot != -1) r.reset(); + break; + } + } catch (IOException ioe) { + throw parseError(orig, ioe); + } + } + String[] result = new String[v.size()]; + v.copyInto(result); + return result; + } + + /** + * Read an id from the StringReader provided. + *

+ *

+ * Raise an exception if the first thing in the StringReader + * is not a valid id. + *

+ * + * @param orig The full text being parsed + * @param r The multi-part identifier to be parsed + * @param quoteCharacter The character which frames a delimited id (e.g., " or `) + * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. + * @throws SparkException Invalid SQL identifier. + */ + private static String parseId( + String orig, + StringReader r, + char quoteCharacter, + boolean upperCaseIdentifiers) + throws SparkException { + try { + r.mark(0); + int c = r.read(); + if (c == -1) { //id can't be 0-length + throw parseError(orig, null); + } + r.reset(); + if (c == quoteCharacter) { + return parseQId(orig, r, quoteCharacter); + } else { + return parseUnQId(orig, r, upperCaseIdentifiers); + } + } catch (IOException ioe) { + throw parseError(orig, ioe); + } + } + + /** + * Parse a regular identifier (unquoted) returning returning either + * the value of the identifier or a delimited identifier. Ensures + * that all characters in the identifer are valid for a regular identifier. + * + * @param orig The full text being parsed + * @param r Regular identifier to parse. + * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. + * @return the value of the identifer or a delimited identifier + * @throws SparkException Error accessing value + */ + private static String parseUnQId( + String orig, + StringReader r, + boolean upperCaseIdentifiers) + throws SparkException { + StringBuffer b = new StringBuffer(); + int c; + boolean first; + + try { + for (first = true; ; first = false) { + r.mark(0); + if (idChar(first, c = r.read())) { + b.append((char) c); + } else { + break; + } + } + + if (c != -1) { + r.reset(); + } + } catch (IOException ioe) { + throw parseError(orig, ioe); + } + + String id = b.toString(); + + return adjustCase(id, upperCaseIdentifiers); + } + + private static boolean idChar(boolean first, int c) { + if (((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) || + (!first && (c >= '0' && c <= '9')) || (!first && c == '_')) { + return true; + } else if (Character.isLetter((char) c)) { + return true; + } else if (!first && Character.isDigit((char) c)) { + return true; + } + + return false; + } + + /** + * Adjust the case of an unquoted identifier. + * Always use the java.util.ENGLISH locale + * + * @param s string to uppercase + * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. + * @return uppercased string + */ + private static String adjustCase(String s, boolean upperCaseIdentifiers) { + if ( upperCaseIdentifiers) { + return s.toUpperCase(Locale.ENGLISH); + } + else { + return s.toLowerCase(Locale.ENGLISH); + } + } + + /** + * Parse a delimited (quoted) identifier returning either + * the value of the identifier or a delimited identifier. + * + * @param orig The full text being parsed + * @param r Quoted identifier to parse. + * @return the value of the identifer or a delimited identifier + * @throws SparkException Error parsing identifier. + */ + private static String parseQId( + String orig, + StringReader r, + char quoteCharacter) + throws SparkException { + StringBuffer b = new StringBuffer(); + + try { + int c = r.read(); + if (c != quoteCharacter) { + throw parseError(orig, null); + } + + while (true) { + c = r.read(); + if (c == quoteCharacter) { + r.mark(0); + int c2 = r.read(); + if (c2 != quoteCharacter) { + if (c2 != -1) { + r.reset(); + } + break; + } + } else if (c == -1) { + throw parseError(orig, null); + } + + b.append((char) c); + } + } catch (IOException ioe) { + throw parseError(orig, ioe); + } + + if (b.length() == 0) { //id can't be 0-length + throw parseError(orig, null); + } + + return b.toString(); + } + + /** + * Verify the read is empty (no more characters in its stream). + * + * @param orig The full text being parsed + * @param r + * @throws SparkException + */ + private static void verifyEmpty(String orig, java.io.Reader r) + throws SparkException { + try { + if (r.read() != -1) { + throw parseError(orig, null); + } + } catch (IOException ioe) { + throw parseError(orig, ioe); + } + } + + /** + * Create a parsing exception. + * + * @param orig The full text being parsed + * @param cause Optional original exception + * @return A SparkException describing a parsing error. + */ + private static SparkException parseError(String orig, Exception cause) { + String message = "Error parsing SQL identifier: " + orig; + + if (cause != null) { + return new SparkException(message); + } else { + return new SparkException(message, cause); + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7887e559a302..1bce5d39a0ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql import java.util.Properties +import org.apache.spark.sql.jdbc.JdbcDialects + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -255,6 +257,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { // connectionProperties should override settings in extraOptions props.putAll(connectionProperties) val conn = JdbcUtils.createConnection(url, props) + val dialect = JdbcDialects.get(url) try { var tableExists = JdbcUtils.tableExists(conn, url, table) @@ -268,13 +271,14 @@ final class DataFrameWriter private[sql](df: DataFrame) { } if (mode == SaveMode.Overwrite && tableExists) { - JdbcUtils.dropTable(conn, table) + JdbcUtils.dropTable(conn, dialect, table) tableExists = false } // Create the table if the table didn't exist. if (!tableExists) { val schema = JdbcUtils.schemaString(df, url) + dialect.vetSqlIdentifier(table) val sql = s"CREATE TABLE $table ($schema)" conn.prepareStatement(sql).executeUpdate() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index f89d55b20e21..fc50b91a7622 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.util.Try import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.JdbcDialects +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} @@ -54,14 +54,20 @@ object JdbcUtils extends Logging { /** * Drops a table from the JDBC database. */ - def dropTable(conn: Connection, table: String): Unit = { + def dropTable(conn: Connection, dialect: JdbcDialect, table: String): Unit = { + dialect.vetSqlIdentifier(table) conn.prepareStatement(s"DROP TABLE $table").executeUpdate() } /** * Returns a PreparedStatement that inserts a row into table via conn. */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { + def insertStatement( + conn: Connection, + dialect: JdbcDialect, + table: String, + rddSchema: StructType): PreparedStatement = { + dialect.vetSqlIdentifier(table) val sql = new StringBuilder(s"INSERT INTO $table VALUES (") var fieldsLeft = rddSchema.fields.length while (fieldsLeft > 0) { @@ -88,6 +94,7 @@ object JdbcUtils extends Logging { */ def savePartition( getConnection: () => Connection, + dialect: JdbcDialect, table: String, iterator: Iterator[Row], rddSchema: StructType, @@ -97,7 +104,7 @@ object JdbcUtils extends Logging { var committed = false try { conn.setAutoCommit(false) // Everything in the same db transaction. - val stmt = insertStatement(conn, table, rddSchema) + val stmt = insertStatement(conn, dialect, table, rddSchema) try { var rowCount = 0 while (iterator.hasNext) { @@ -225,8 +232,10 @@ object JdbcUtils extends Logging { val driver: String = DriverRegistry.getDriverClassName(url) val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt + val jdbcDialect = JdbcDialects.get(url) df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) + savePartition(getConnection, jdbcDialect, table, iterator, + rddSchema, nullTypes, batchSize) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 88ae83957a70..d8787be68272 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.jdbc import java.sql.Types +import org.apache.spark.SparkException +import org.apache.spark.sql.SqlIdentifierUtil._ import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -62,6 +64,12 @@ abstract class JdbcDialect { */ def canHandle(url : String): Boolean + /** + * Return the character used to frame delimited identifiers in this database. + * @return The delimited id character (usually ", sometimes `) + */ + def quoteChar: Char = '"' + /** * Get the custom datatype mapping for the given jdbc meta information. * @param sqlType The sql type (see java.sql.Types) @@ -86,19 +94,63 @@ abstract class JdbcDialect { * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). */ def quoteIdentifier(colName: String): String = { - s""""$colName"""" + quoteString(colName, quoteChar) + } + + /** + * Get the SQL query that should be used to find if the given table exists. + * Call this method (and not tableExistsQuery) in order to verify + * that the table name is properly formed. + * @param table The name of the table. + * @return The SQL query to use for checking the table. + * @throws org.apache.spark.SparkException On invalid table name. + */ + final def getTableExistsQuery(table: String): String = { + vetSqlIdentifier(table) + tableExistsQuery(table) } /** * Get the SQL query that should be used to find if the given table exists. Dialects can * override this method to return a query that works best in a particular database. + * Don't expose this method outside this class and its subclasses. + * Other consumers should call getTableExistsQuery instead. That method + * verifies that the table name is properly formed. * @param table The name of the table. * @return The SQL query to use for checking the table. */ - def getTableExistsQuery(table: String): String = { + protected def tableExistsQuery(table: String): String = { s"SELECT * FROM $table WHERE 1=0" } + /** Vet a user-supplied object id of the form + * [[catalog.]schema.]objectName + * by parsing it into a (catalog, schema, objectName) + * triple. The catalog and schema names may be empty. Raises + * a SparkException if the user-supplied id is malformed, + * e.g., is a string like "foo; drop database finance;" + * intended for a SQL injection attack. + * + * @param rawId The user-supplied object id (name). + * @throws org.apache.spark.SparkException On invalid ids. + */ + def vetSqlIdentifier(rawId: String) { + + // It's ok to assume that the database uppercases unquoted + // identifiers. That's because we aren't actually returning the + // parsed result to the user. The case-sensitivity of SQL + // identifiers is a tricky topic. See, for instance: + // https://github.com/ontop/ontop/wiki/Case-sensitivity-for-SQL-identifiers + val parsed : Array[String] = parseMultiPartSQLIdentifier(rawId, + quoteChar, true) + + parsed.length match { + case 1 => (null, null, parsed(0)) + case 2 => (null, parsed(0), parsed(1)) + case 3 => (parsed(0), parsed(1), parsed(2)) + case _ => throw new SparkException("Unparsable object id: " + rawId) + } + } } /** @@ -152,6 +204,12 @@ object JdbcDialects { case _ => new AggregatedDialect(matchingDialects) } } + + /** + * Get all dialects (useful for testing purposes). + */ + private[sql] def getAllDialects(): List[JdbcDialect] = dialects + } /** @@ -217,7 +275,7 @@ case object PostgresDialect extends JdbcDialect { case _ => None } - override def getTableExistsQuery(table: String): String = { + override protected def tableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } @@ -230,6 +288,7 @@ case object PostgresDialect extends JdbcDialect { @DeveloperApi case object MySQLDialect extends JdbcDialect { override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + override def quoteChar: Char = '`' override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { @@ -242,11 +301,7 @@ case object MySQLDialect extends JdbcDialect { } else None } - override def quoteIdentifier(colName: String): String = { - s"`$colName`" - } - - override def getTableExistsQuery(table: String): String = { + override protected def tableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d530b1a469ce..b539cfd6d79b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -25,6 +25,7 @@ import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.SparkException import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -484,4 +485,59 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(h2.getTableExistsQuery(table) == defaultQuery) assert(derby.getTableExistsQuery(table) == defaultQuery) } + + /** + * Verify that the JdbcDialect rejects an illegal name. + * @param dialect The JdbcDialect. + * @param tableName A bad table name. + */ + def badNameVetter(dialect: JdbcDialect, tableName: String) { + val badTableName = tableName.replace( '"', dialect.quoteChar) + val dialectName = dialect.getClass.getName + try { + dialect.vetSqlIdentifier(badTableName) + fail(dialectName + " should have rejected " + badTableName) + } catch { + case exc: SparkException => { + val expectedMessage = "Error parsing SQL identifier: " + badTableName + assert(expectedMessage == exc.getMessage) + } + case obj: Throwable => { + val badResponse = obj.toString + fail("Unexpected vetting failure when dialect " + dialectName + + " vets table name " + badTableName + ": " + badResponse) + } + } + } + + test("verify that JDBC dialects vet table names") { + val badNames = + """bad"name""" :: + """foo; drop database finance;""" :: + Nil + val allDialects = JdbcDialects.getAllDialects() + for ( d <- allDialects; b <- badNames ) badNameVetter(d, b) + + val goodNames = + """foo.bar""" :: + """"foo".bar""" :: + """"foo"."bar"""" :: + """foo."bar"""" :: + """"foo.bar"""" :: + Nil + for ( d <- allDialects; g <- goodNames ) { + val goodTableName = g.replace( '"', d.quoteChar) + try { + d.getTableExistsQuery(goodTableName) + } catch { + case obj: Throwable => { + val errorMessage = "Dang. " + d.getClass.getName + + " couldn't handle " + + goodTableName + ": " + obj.toString + fail(errorMessage) + } + } + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e23ee6693133..c795525b4b88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties +import org.apache.spark.SparkException import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{Row, SaveMode} @@ -151,4 +152,28 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } + + test("Negative: CREATE with illegal table name") { + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + df.write.jdbc(url, "TEST.DUMMY", new Properties) + assert(2 === sqlContext.read.jdbc(url, "TEST.DUMMY", new Properties).count) + + try { + df.write.jdbc(url, "TEST.FOO(A INT); DROP TABLE TEST.DUMMY;", new Properties) + fail("Table creation should have failed.") + } catch { + case exc: SparkException => { + val expectedMessage = "Error parsing SQL identifier: TEST.FOO(A INT); DROP TABLE TEST.DUMMY;" + assert(expectedMessage == exc.getMessage) + } + case obj: Throwable => { + val badResponse = obj.toString + fail("Unexpected failure for table creation: " + badResponse) + } + } + assert(2 === sqlContext.read.jdbc(url, "TEST.DUMMY", new Properties).count) + } + + } From 25481751cc08f0e903ac20f960c6891f932d754c Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Wed, 21 Oct 2015 11:46:26 -0700 Subject: [PATCH 2/6] SPARK-10857: Add and remove some comments. --- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index d8787be68272..d3c275466e8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -128,19 +128,14 @@ abstract class JdbcDialect { * by parsing it into a (catalog, schema, objectName) * triple. The catalog and schema names may be empty. Raises * a SparkException if the user-supplied id is malformed, - * e.g., is a string like "foo; drop database finance;" - * intended for a SQL injection attack. + * e.g., is a string like "foo; drop database finance;", + * something intended for a SQL injection attack. * * @param rawId The user-supplied object id (name). * @throws org.apache.spark.SparkException On invalid ids. */ def vetSqlIdentifier(rawId: String) { - // It's ok to assume that the database uppercases unquoted - // identifiers. That's because we aren't actually returning the - // parsed result to the user. The case-sensitivity of SQL - // identifiers is a tricky topic. See, for instance: - // https://github.com/ontop/ontop/wiki/Case-sensitivity-for-SQL-identifiers val parsed : Array[String] = parseMultiPartSQLIdentifier(rawId, quoteChar, true) From 2865416af96c2921ab44b79c69868bbf945c2b29 Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Wed, 21 Oct 2015 11:46:56 -0700 Subject: [PATCH 3/6] SPARK-10857: Add and remove some comments. --- .../org/apache/spark/sql/jdbc/JdbcDialects.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index d3c275466e8d..8f482ae26ab0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -299,6 +299,17 @@ case object MySQLDialect extends JdbcDialect { override protected def tableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } + + // The default implementation of this method allows embedded, + // escaped quotes inside quoted identifiers. SQL Server does not + // allow embedded quotes. This means that this method won't catch + // some illegal table names. Those names will appear to SQL Server as an + // ungrammatical sequence of quoted identifiers. In order to get + // a better error message, someone may want to provide an + // implementation which handles the SQL Server grammar better. + // + //override def vetSqlIdentifier(rawId: String) + } /** From 287e8f0c0e6424d34cc0a0f4fe8ab85e345a1443 Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Wed, 21 Oct 2015 11:50:56 -0700 Subject: [PATCH 4/6] DERBY-10857: Fix a scalastyle problem. --- .../src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 8f482ae26ab0..371f9e67e0ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -308,7 +308,7 @@ case object MySQLDialect extends JdbcDialect { // a better error message, someone may want to provide an // implementation which handles the SQL Server grammar better. // - //override def vetSqlIdentifier(rawId: String) + // override def vetSqlIdentifier(rawId: String) } From 13538203eedb52054f3c8ed481fbfc7161a237a0 Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Wed, 21 Oct 2015 15:33:41 -0700 Subject: [PATCH 5/6] SPARK-10857: Fix a long line style problem. --- .../test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index c795525b4b88..0430c2b50d9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -164,7 +164,8 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { fail("Table creation should have failed.") } catch { case exc: SparkException => { - val expectedMessage = "Error parsing SQL identifier: TEST.FOO(A INT); DROP TABLE TEST.DUMMY;" + val expectedMessage = + "Error parsing SQL identifier: TEST.FOO(A INT); DROP TABLE TEST.DUMMY;" assert(expectedMessage == exc.getMessage) } case obj: Throwable => { From 6c906bb21a04a99ddf992863f552ed9e0b36d574 Mon Sep 17 00:00:00 2001 From: Rick Hillegas Date: Tue, 27 Oct 2015 11:48:38 -0700 Subject: [PATCH 6/6] SPARK-10857: Replace the Java id parser with a Scala version which is based on regular expressions; introduce a case object to represent a 3-part table identifier. --- .../apache/spark/sql/SqlIdentifierUtil.java | 322 ------------------ .../org/apache/spark/sql/SqlIdUtil.scala | 241 +++++++++++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 13 +- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 12 + 4 files changed, 256 insertions(+), 332 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java deleted file mode 100644 index ff05c53b6693..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/SqlIdentifierUtil.java +++ /dev/null @@ -1,322 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import java.io.IOException; -import java.io.StringReader; -import java.util.Locale; -import java.util.Vector; -import org.apache.spark.SparkException; - -/** - * Methods for handling SQL identifiers. These methods were cribbed - * from org.apache.derby.iapi.util.IdUtil and - * org.apache.derby.iapi.util.StringUtil. - */ -public class SqlIdentifierUtil { - - /** - * Quote a string so that it can be used as an identifier or a string - * literal in SQL statements. Identifiers are surrounded by double quotes - * and string literals are surrounded by single quotes. If the string - * contains quote characters, they are escaped. - * - * @param source the string to quote - * @param quote the character to quote the string with (' or ") - * @return a string quoted with the specified quote character - */ - public static String quoteString(String source, char quote) { - // Normally, the quoted string is two characters longer than the source - // string (because of start quote and end quote). - StringBuffer quoted = new StringBuffer(source.length() + 2); - quoted.append(quote); - for (int i = 0; i < source.length(); i++) { - char c = source.charAt(i); - // if the character is a quote, escape it with an extra quote - if (c == quote) quoted.append(quote); - quoted.append(c); - } - quoted.append(quote); - return quoted.toString(); - } - - /** - * Parse a multi-part (dot separated) SQL identifier from the - * String provided. Raise an excepion - * if the string does not contain valid SQL indentifiers. - * The returned String array contains the normalized form of the - * identifiers. - * - * @param s The string to be parsed - * @param quoteCharacter The character which frames a delimited id (e.g., " or `) - * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. - * @return An array of strings made by breaking the input string at its dots, '.'. - * @throws SparkException Invalid SQL identifier. - */ - public static String[] parseMultiPartSQLIdentifier( - String s, - char quoteCharacter, - boolean upperCaseIdentifiers) - throws SparkException { - StringReader r = new StringReader(s); - String[] qName = parseMultiPartSQLIdentifier( - s, - r, - quoteCharacter, - upperCaseIdentifiers); - verifyEmpty(s, r); - return qName; - } - - /** - * Parse a multi-part (dot separated) SQL identifier from the - * String provided. Raise an excepion - * if the string does not contain valid SQL indentifiers. - * The returned String array contains the normalized form of the - * identifiers. - * - * @param orig The full text being parsed - * @param r The multi-part identifier to be parsed - * @param quoteCharacter The character which frames a delimited id (e.g., " or `) - * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. - * @return An array of strings made by breaking the input string at its dots, '.'. - * @throws SparkException Invalid SQL identifier. - */ - private static String[] parseMultiPartSQLIdentifier( - String orig, - StringReader r, - char quoteCharacter, - boolean upperCaseIdentifiers) - throws SparkException { - Vector v = new Vector(); - while (true) { - String thisId = parseId(orig, r, quoteCharacter, upperCaseIdentifiers); - v.add(thisId); - int dot; - - try { - r.mark(0); - dot = r.read(); - if (dot != '.') { - if (dot != -1) r.reset(); - break; - } - } catch (IOException ioe) { - throw parseError(orig, ioe); - } - } - String[] result = new String[v.size()]; - v.copyInto(result); - return result; - } - - /** - * Read an id from the StringReader provided. - *

- *

- * Raise an exception if the first thing in the StringReader - * is not a valid id. - *

- * - * @param orig The full text being parsed - * @param r The multi-part identifier to be parsed - * @param quoteCharacter The character which frames a delimited id (e.g., " or `) - * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. - * @throws SparkException Invalid SQL identifier. - */ - private static String parseId( - String orig, - StringReader r, - char quoteCharacter, - boolean upperCaseIdentifiers) - throws SparkException { - try { - r.mark(0); - int c = r.read(); - if (c == -1) { //id can't be 0-length - throw parseError(orig, null); - } - r.reset(); - if (c == quoteCharacter) { - return parseQId(orig, r, quoteCharacter); - } else { - return parseUnQId(orig, r, upperCaseIdentifiers); - } - } catch (IOException ioe) { - throw parseError(orig, ioe); - } - } - - /** - * Parse a regular identifier (unquoted) returning returning either - * the value of the identifier or a delimited identifier. Ensures - * that all characters in the identifer are valid for a regular identifier. - * - * @param orig The full text being parsed - * @param r Regular identifier to parse. - * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. - * @return the value of the identifer or a delimited identifier - * @throws SparkException Error accessing value - */ - private static String parseUnQId( - String orig, - StringReader r, - boolean upperCaseIdentifiers) - throws SparkException { - StringBuffer b = new StringBuffer(); - int c; - boolean first; - - try { - for (first = true; ; first = false) { - r.mark(0); - if (idChar(first, c = r.read())) { - b.append((char) c); - } else { - break; - } - } - - if (c != -1) { - r.reset(); - } - } catch (IOException ioe) { - throw parseError(orig, ioe); - } - - String id = b.toString(); - - return adjustCase(id, upperCaseIdentifiers); - } - - private static boolean idChar(boolean first, int c) { - if (((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) || - (!first && (c >= '0' && c <= '9')) || (!first && c == '_')) { - return true; - } else if (Character.isLetter((char) c)) { - return true; - } else if (!first && Character.isDigit((char) c)) { - return true; - } - - return false; - } - - /** - * Adjust the case of an unquoted identifier. - * Always use the java.util.ENGLISH locale - * - * @param s string to uppercase - * @param upperCaseIdentifiers True if SQL ids are normalized to upper case. - * @return uppercased string - */ - private static String adjustCase(String s, boolean upperCaseIdentifiers) { - if ( upperCaseIdentifiers) { - return s.toUpperCase(Locale.ENGLISH); - } - else { - return s.toLowerCase(Locale.ENGLISH); - } - } - - /** - * Parse a delimited (quoted) identifier returning either - * the value of the identifier or a delimited identifier. - * - * @param orig The full text being parsed - * @param r Quoted identifier to parse. - * @return the value of the identifer or a delimited identifier - * @throws SparkException Error parsing identifier. - */ - private static String parseQId( - String orig, - StringReader r, - char quoteCharacter) - throws SparkException { - StringBuffer b = new StringBuffer(); - - try { - int c = r.read(); - if (c != quoteCharacter) { - throw parseError(orig, null); - } - - while (true) { - c = r.read(); - if (c == quoteCharacter) { - r.mark(0); - int c2 = r.read(); - if (c2 != quoteCharacter) { - if (c2 != -1) { - r.reset(); - } - break; - } - } else if (c == -1) { - throw parseError(orig, null); - } - - b.append((char) c); - } - } catch (IOException ioe) { - throw parseError(orig, ioe); - } - - if (b.length() == 0) { //id can't be 0-length - throw parseError(orig, null); - } - - return b.toString(); - } - - /** - * Verify the read is empty (no more characters in its stream). - * - * @param orig The full text being parsed - * @param r - * @throws SparkException - */ - private static void verifyEmpty(String orig, java.io.Reader r) - throws SparkException { - try { - if (r.read() != -1) { - throw parseError(orig, null); - } - } catch (IOException ioe) { - throw parseError(orig, ioe); - } - } - - /** - * Create a parsing exception. - * - * @param orig The full text being parsed - * @param cause Optional original exception - * @return A SparkException describing a parsing error. - */ - private static SparkException parseError(String orig, Exception cause) { - String message = "Error parsing SQL identifier: " + orig; - - if (cause != null) { - return new SparkException(message); - } else { - return new SparkException(message, cause); - } - } - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala new file mode 100644 index 000000000000..5c9178131d22 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.Locale +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.SparkException + +/** + * A three part table identifier. The first two parts can be null. + * + * @param database The database name. + * @param schema The schema name. + * @param table The table name. + */ +case class TableId(database: String, schema: String, table: String) + +/** + * Utility methods for SQL identifiers. These methods were loosely + * translated from org.apache.derby.iapi.util.IdUtil and + * org.apache.derby.iapi.util.StringUtil. + */ +object SqlIdUtil { + + private val OneQuote = """"""" + private val TwoQuotes = """""""" + private val DefaultQuote = '"' + + // Regular expression defining one id in a dot-separated SQL identifier chain + private val OneIdString = + "(\\s)*((" + // leading spaces ok + """\p{Alpha}(\p{Alnum}|_)*""" + // regular identifier (no quotes) + ")|(" + // or + """"(""|[^"])+"""" + // delimited identifier (quoted) + "))(\\s)*" // trailing spaces ok + + /** + * Quote a string so that it can be used as an identifier or a string + * literal in SQL statements. Identifiers are usually surrounded by double quotes + * and string literals are surrounded by single quotes. If the string + * contains quote characters, they are escaped. + * + * @param source the string to quote + * @param quote the framing quote character (e.g.: ', ", `) + * @return a string quoted with the indicated quote character + */ + def quoteString(source: String, quote: Char): String = { + // Normally, the quoted string is two characters longer than the source + // string (because of start quote and end quote). + val quoted = new StringBuilder(source.length() + 2) + + quoted.append(quote) + for (ch <- source) { + quoted.append(ch) + if (ch == quote) quoted.append(quote) + } + quoted.append(quote) + quoted.toString() + } + + /** Parse a user-supplied object id of the form + * [[database.]schema.]objectName + * into a TableIdentifier(database, schema, objectName). + * The database and schema names may be empty. The caller + * must supply the database-specific quote character which is used + * to frame delimited ids. For most databases this is the " + * character. For Hive, this is the ` character. The caller must + * specify whether the database uppercases or lowercases + * unquoted identifiers when they are stored in its metadata + * catalogs. + * + * The fields of the TableIdentifier are normalized to the case + * convention used by the database's catalogs. So for a database + * which uses " for quoted identifiers and which uppercases + * ids in its metadata catalogs, the string + * + * "foo".bar + * + * would result in + * + * TableIdentifier( null, foo, BAR ) + * + * @param rawName The user-supplied name. + * @param quote The db-specific character which frames delimited ids. + * @param upperCase True if the db uppercases un-delimited ids. + */ + def parseSqlIds( + rawName: String, + quote: Char, + upperCase: Boolean): TableId = { + val parsed = parseMultiPartSqlIdentifier(rawName, + quote, upperCase) + + parsed.length match { + case 1 => TableId(null, null, parsed(0)) + case 2 => TableId(null, parsed(0), parsed(1)) + case 3 => TableId(parsed(0), parsed(1), parsed(2)) + case _ => throw new Exception("Unparsable object id: " + rawName) + } + } + + /** + * Parse a multi-part (dot separated) chain of SQL identifiers from the + * String provided. Raise an excepion + * if the string does not contain valid SQL indentifiers. + * The returned String array contains the normalized form of the + * identifiers. + * + * @param rawName The string to be parsed + * @param quote The character which frames a delimited id (e.g., " or `) + * @param upperCase True if SQL ids are normalized to upper case. + * @return An array of strings made by breaking the input string at its dots, '.'. + * @throws SparkException Invalid SQL identifier. + */ + private def parseMultiPartSqlIdentifier( + rawName: String, + quote: Char, + upperCase: Boolean): ArrayBuffer[String] = { + + // construct the regex, accounting for the caller-supplied quote character + var regexString = OneIdString + if (quote != DefaultQuote) + { + regexString = regexString.replace(DefaultQuote, quote) + } + val oneIdRegex = regexString.r + + // + // Loop through the raw string, one identifier at a time. + // Discard spaces around the identifiers. Discard + // the dots which separate one identifier from the next. + // + var result = ArrayBuffer[String]() + var keepGoing = true + var remainingString = rawName + while (keepGoing) + { + oneIdRegex.findPrefixOf(remainingString) match { + + case Some(paddedId) => { + val paddedIdLength = paddedId.length + result.append(normalize(paddedId.trim, quote, upperCase)) + if (remainingString.length == paddedIdLength) { + keepGoing = false // we're done. hooray. + } + else if (remainingString.charAt(paddedIdLength) == '.') { + // chop off the old identifier and the dot separator. + // continue looking for more ids in the rest of the string. + remainingString = remainingString.substring(paddedIdLength + 1) + } + else { + throw parseError(rawName) + } + } + + case _ => { + throw parseError(rawName) + } + } // end matching an id + + } // end of loop through ids + + result + } + + /** + * Normalize a SQL identifier to the case used by the target + * database's metadata catalogs. + * + * @param rawName The string to be normalized (may be framed by quotes) + * @param quote The character which frames a delimited id (e.g., " or `) + * @param upperCase True if SQL ids are normalized to upper case. + * @return An array of strings made by breaking the input string at its dots, '.'. + */ + private def normalize(rawName: String, quote: Char, upperCase: Boolean): String = { + + // regular id + if (rawName.charAt(0) != quote) adjustCase(rawName, upperCase) + // delimited id + else stripQuotes(rawName, quote) + } + + /** + * Adjust the case of an unquoted identifier to the case convention + * used by the metadata catalogs of the target database. + * Always use the java.util.ENGLISH locale. + * + * @param rawName string to uppercase + * @param upperCase True if SQL ids are normalized to upper case. + * @return The properly cased string. + */ + private def adjustCase(rawName: String, upperCase: Boolean): String = { + if (upperCase) rawName.toUpperCase(Locale.ENGLISH) + else rawName.toLowerCase(Locale.ENGLISH) + } + + /** + * Strip framing quotes from a delimited id and un-escape interior quotes. + * + * @param rawName string to uppercase + * @param quote the database-specific quote character. + * @return The properly cased string. + */ + private def stripQuotes(rawName: String, quote: Char): String = { + var oneQuote = OneQuote + var twoQuotes = TwoQuotes + if ( quote != DefaultQuote) + { + val oneQuote = OneQuote.replace(DefaultQuote, quote) + val twoQuotes = TwoQuotes.replace(DefaultQuote, quote) + } + rawName.substring(1, rawName.length - 1).replace(twoQuotes, oneQuote) + } + + /** + * Create a parsing exception. + * + * @param orig The full text being parsed + * @return A SparkException describing a parsing error. + */ + private def parseError(orig: String): SparkException = { + new SparkException("Error parsing SQL identifier: " + orig) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 371f9e67e0ea..876d782c0c90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import org.apache.spark.SparkException -import org.apache.spark.sql.SqlIdentifierUtil._ +import org.apache.spark.sql.SqlIdUtil._ import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -136,15 +136,8 @@ abstract class JdbcDialect { */ def vetSqlIdentifier(rawId: String) { - val parsed : Array[String] = parseMultiPartSQLIdentifier(rawId, - quoteChar, true) - - parsed.length match { - case 1 => (null, null, parsed(0)) - case 2 => (null, parsed(0), parsed(1)) - case 3 => (parsed(0), parsed(1), parsed(2)) - case _ => throw new SparkException("Unparsable object id: " + rawId) - } + // raises a SparkException if the string doesn't parse + parseSqlIds(rawId, quoteChar, true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index b539cfd6d79b..6579bb0fe5a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -514,6 +514,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val badNames = """bad"name""" :: """foo; drop database finance;""" :: + """foo.""" :: + """.foo""" :: + """foo bar""" :: + """fo"o""" :: Nil val allDialects = JdbcDialects.getAllDialects() for ( d <- allDialects; b <- badNames ) badNameVetter(d, b) @@ -524,6 +528,14 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext """"foo"."bar"""" :: """foo."bar"""" :: """"foo.bar"""" :: + """ foo""" :: + """"foo"""" :: + """"foo bar"""" :: + """"foo."""" :: + """foo.bar""" :: + """foo .bar""" :: + """foo.bar.wibble""" :: + """"fo""o"""" :: Nil for ( d <- allDialects; g <- goodNames ) { val goodTableName = g.replace( '"', d.quoteChar)