Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.util.Locale
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkException

/**
 * A three part table identifier of the form database.schema.table.
 *
 * The first two parts can be null: `SqlIdUtil.parseSqlIds` leaves `database`
 * (and `schema`) null when the parsed name has fewer than three (or two) parts.
 *
 * @param database The database name, or null if none was given.
 * @param schema The schema name, or null if none was given.
 * @param table The table name.
 */
case class TableId(database: String, schema: String, table: String)

/**
 * Utility methods for SQL identifiers. These methods were loosely
 * translated from org.apache.derby.iapi.util.IdUtil and
 * org.apache.derby.iapi.util.StringUtil.
 */
object SqlIdUtil {

  // The quote character baked into OneIdString. When a caller supplies a
  // different quote character, it is substituted for this one.
  private val DefaultQuote = '"'

  // Regular expression defining one id in a dot-separated SQL identifier chain
  private val OneIdString =
    "(\\s)*((" + // leading spaces ok
    """\p{Alpha}(\p{Alnum}|_)*""" + // regular identifier (no quotes)
    ")|(" + // or
    """"(""|[^"])+"""" + // delimited identifier (quoted)
    "))(\\s)*" // trailing spaces ok

  /**
   * Quote a string so that it can be used as an identifier or a string
   * literal in SQL statements. Identifiers are usually surrounded by double quotes
   * and string literals are surrounded by single quotes. Every occurrence of the
   * quote character inside the string is escaped by doubling it.
   *
   * @param source the string to quote
   * @param quote the framing quote character (e.g.: ', ", `)
   * @return a string quoted with the indicated quote character
   */
  def quoteString(source: String, quote: Char): String = {
    // Normally, the quoted string is two characters longer than the source
    // string (because of start quote and end quote).
    val quoted = new StringBuilder(source.length + 2)

    quoted.append(quote)
    source.foreach { ch =>
      quoted.append(ch)
      // escape an embedded quote by writing it twice
      if (ch == quote) quoted.append(quote)
    }
    quoted.append(quote)
    quoted.toString()
  }

  /**
   * Parse a user-supplied object id of the form
   * [[database.]schema.]objectName
   * into a TableId(database, schema, objectName).
   * The database and schema names may be omitted, in which case the
   * corresponding fields of the result are null. The caller
   * must supply the database-specific quote character which is used
   * to frame delimited ids. For most databases this is the "
   * character. For Hive, this is the ` character. The caller must
   * specify whether the database uppercases or lowercases
   * unquoted identifiers when they are stored in its metadata
   * catalogs.
   *
   * The fields of the TableId are normalized to the case
   * convention used by the database's catalogs. So for a database
   * which uses " for quoted identifiers and which uppercases
   * ids in its metadata catalogs, the string
   *
   * "foo".bar
   *
   * would result in
   *
   * TableId( null, foo, BAR )
   *
   * @param rawName The user-supplied name.
   * @param quote The db-specific character which frames delimited ids.
   * @param upperCase True if the db uppercases un-delimited ids.
   * @return The parsed, normalized identifier.
   * @throws SparkException if rawName is not a valid identifier chain or
   *                        has more than three parts.
   */
  def parseSqlIds(
      rawName: String,
      quote: Char,
      upperCase: Boolean): TableId = {
    val parsed = parseMultiPartSqlIdentifier(rawName, quote, upperCase)

    parsed.length match {
      case 1 => TableId(null, null, parsed(0))
      case 2 => TableId(null, parsed(0), parsed(1))
      case 3 => TableId(parsed(0), parsed(1), parsed(2))
      // SparkException (a subclass of Exception) keeps this consistent with
      // the other parse failures raised in this object.
      case _ => throw new SparkException("Unparsable object id: " + rawName)
    }
  }

  /**
   * Parse a multi-part (dot separated) chain of SQL identifiers from the
   * String provided. Raise an exception
   * if the string does not contain valid SQL identifiers.
   * The returned buffer contains the normalized form of the
   * identifiers.
   *
   * NOTE: the quote character is substituted textually into the regular
   * expression, so it must not be a regex metacharacter.
   *
   * @param rawName The string to be parsed
   * @param quote The character which frames a delimited id (e.g., " or `)
   * @param upperCase True if SQL ids are normalized to upper case.
   * @return The identifiers made by breaking the input string at its dots, '.'.
   * @throws SparkException Invalid SQL identifier.
   */
  private def parseMultiPartSqlIdentifier(
      rawName: String,
      quote: Char,
      upperCase: Boolean): ArrayBuffer[String] = {

    // construct the regex, accounting for the caller-supplied quote character
    val regexString =
      if (quote == DefaultQuote) OneIdString
      else OneIdString.replace(DefaultQuote, quote)
    val oneIdRegex = regexString.r

    val result = ArrayBuffer[String]()

    // Walk the raw string one identifier at a time. Discard spaces around
    // the identifiers and the dots which separate one identifier from the next.
    @scala.annotation.tailrec
    def consume(remaining: String): Unit = {
      val paddedId = oneIdRegex.findPrefixOf(remaining).getOrElse(throw parseError(rawName))
      result += normalize(paddedId.trim, quote, upperCase)

      val consumed = paddedId.length
      if (consumed < remaining.length) {
        // anything following an identifier must be a dot separator
        if (remaining.charAt(consumed) != '.') throw parseError(rawName)
        // chop off the old identifier and the dot; keep looking for more ids
        consume(remaining.substring(consumed + 1))
      }
    }

    consume(rawName)
    result
  }

  /**
   * Normalize a SQL identifier to the case used by the target
   * database's metadata catalogs. Delimited (quoted) ids keep their case
   * but lose their framing quotes.
   *
   * @param rawName The string to be normalized (may be framed by quotes);
   *                must be non-empty.
   * @param quote The character which frames a delimited id (e.g., " or `)
   * @param upperCase True if SQL ids are normalized to upper case.
   * @return The normalized identifier.
   */
  private def normalize(rawName: String, quote: Char, upperCase: Boolean): String = {
    // regular id
    if (rawName.charAt(0) != quote) adjustCase(rawName, upperCase)
    // delimited id
    else stripQuotes(rawName, quote)
  }

  /**
   * Adjust the case of an unquoted identifier to the case convention
   * used by the metadata catalogs of the target database.
   * Always use the java.util.Locale.ENGLISH locale.
   *
   * @param rawName string whose case is adjusted
   * @param upperCase True if SQL ids are normalized to upper case.
   * @return The properly cased string.
   */
  private def adjustCase(rawName: String, upperCase: Boolean): String = {
    if (upperCase) rawName.toUpperCase(Locale.ENGLISH)
    else rawName.toLowerCase(Locale.ENGLISH)
  }

  /**
   * Strip framing quotes from a delimited id and un-escape interior quotes
   * (each doubled quote character becomes a single one).
   *
   * @param rawName the delimited id, including its framing quotes
   * @param quote the database-specific quote character.
   * @return The id without framing quotes and with interior quotes un-escaped.
   */
  private def stripQuotes(rawName: String, quote: Char): String = {
    // Derive both strings from the caller's quote character so that
    // non-default quotes (e.g. `) are un-escaped too.
    val oneQuote = quote.toString
    val twoQuotes = oneQuote * 2
    rawName.substring(1, rawName.length - 1).replace(twoQuotes, oneQuote)
  }

  /**
   * Create a parsing exception.
   *
   * @param orig The full text being parsed
   * @return A SparkException describing a parsing error.
   */
  private def parseError(orig: String): SparkException = {
    new SparkException("Error parsing SQL identifier: " + orig)
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql

import java.util.Properties

import org.apache.spark.sql.jdbc.JdbcDialects

import scala.collection.JavaConverters._

import org.apache.spark.annotation.Experimental
Expand Down Expand Up @@ -255,6 +257,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
// connectionProperties should override settings in extraOptions
props.putAll(connectionProperties)
val conn = JdbcUtils.createConnection(url, props)
val dialect = JdbcDialects.get(url)

try {
var tableExists = JdbcUtils.tableExists(conn, url, table)
Expand All @@ -268,13 +271,14 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}

if (mode == SaveMode.Overwrite && tableExists) {
JdbcUtils.dropTable(conn, table)
JdbcUtils.dropTable(conn, dialect, table)
tableExists = false
}

// Create the table if the table didn't exist.
if (!tableExists) {
val schema = JdbcUtils.schemaString(df, url)
dialect.vetSqlIdentifier(table)
val sql = s"CREATE TABLE $table ($schema)"
conn.prepareStatement(sql).executeUpdate()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try

import org.apache.spark.Logging
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}

Expand Down Expand Up @@ -54,14 +54,20 @@ object JdbcUtils extends Logging {
/**
* Drops a table from the JDBC database.
*/
def dropTable(conn: Connection, table: String): Unit = {
def dropTable(conn: Connection, dialect: JdbcDialect, table: String): Unit = {
dialect.vetSqlIdentifier(table)
conn.prepareStatement(s"DROP TABLE $table").executeUpdate()
}

/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
def insertStatement(
conn: Connection,
dialect: JdbcDialect,
table: String,
rddSchema: StructType): PreparedStatement = {
dialect.vetSqlIdentifier(table)
val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
var fieldsLeft = rddSchema.fields.length
while (fieldsLeft > 0) {
Expand All @@ -88,6 +94,7 @@ object JdbcUtils extends Logging {
*/
def savePartition(
getConnection: () => Connection,
dialect: JdbcDialect,
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
Expand All @@ -97,7 +104,7 @@ object JdbcUtils extends Logging {
var committed = false
try {
conn.setAutoCommit(false) // Everything in the same db transaction.
val stmt = insertStatement(conn, table, rddSchema)
val stmt = insertStatement(conn, dialect, table, rddSchema)
try {
var rowCount = 0
while (iterator.hasNext) {
Expand Down Expand Up @@ -225,8 +232,10 @@ object JdbcUtils extends Logging {
val driver: String = DriverRegistry.getDriverClassName(url)
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
val jdbcDialect = JdbcDialects.get(url)
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
savePartition(getConnection, jdbcDialect, table, iterator,
rddSchema, nullTypes, batchSize)
}
}

Expand Down
Loading