diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 96ff389faf4a..c5122843ff20 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -274,6 +274,7 @@ exportMethods("%<=>%", "floor", "format_number", "format_string", + "from_csv", "from_json", "from_unixtime", "from_utc_timestamp", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 572dee50127b..d578c41e0fb0 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -188,6 +188,7 @@ NULL #' \item \code{to_json}: it is the column containing the struct, array of the structs, #' the map or array of maps. #' \item \code{from_json}: it is the column containing the JSON string. +#' \item \code{from_csv}: it is the column containing the CSV string. #' } #' @param y Column to compute on. #' @param value A value to compute on. @@ -196,10 +197,17 @@ NULL #' \item \code{array_position}: a value to locate in the given array. #' \item \code{array_remove}: a value to remove in the given array. #' } -#' @param ... additional argument(s). In \code{to_json} and \code{from_json}, this contains -#' additional named properties to control how it is converted, accepts the same -#' options as the JSON data source. In \code{arrays_zip}, this contains additional -#' Columns of arrays to be merged. +#' @param schema +#' \itemize{ +#' \item \code{from_json}: a structType object to use as the schema +#' when parsing the JSON string. Since Spark 2.3, the DDL-formatted string is +#' also supported for the schema. +#' \item \code{from_csv}: a DDL-formatted string to use as the schema. +#' } +#' @param ... additional argument(s). In \code{to_json}, \code{from_json} and \code{from_csv}, +#' this contains additional named properties to control how it is converted; accepts +#' the same options as the JSON and CSV data sources. In \code{arrays_zip}, +#' this contains additional Columns of arrays to be merged. #' @name column_collection_functions #' @rdname column_collection_functions #' @family collection functions @@ -2164,8 +2172,6 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' to \code{TRUE}. If the string is unparseable, the Column will contain the value NA. #' #' @rdname column_collection_functions -#' @param schema a structType object to use as the schema to use when parsing the JSON string. -#' Since Spark 2.3, the DDL-formatted string is also supported for the schema. #' @param as.json.array indicating if input string is JSON array of objects or a single object. #' @aliases from_json from_json,Column,characterOrstructType-method #' @examples @@ -2202,6 +2208,23 @@ setMethod("from_json", signature(x = "Column", schema = "characterOrstructType") column(jc) }) +#' @details +#' \code{from_csv}: Parses a column containing a CSV string into a Column of \code{structType} +#' with the specified \code{schema}. +#' If the string is unparseable, the Column will contain the value NA. +#' +#' @rdname column_collection_functions +#' @aliases from_csv from_csv,Column,character-method +#' @note from_csv since 3.0.0 +setMethod("from_csv", signature(x = "Column", schema = "character"), + function(x, schema, ...) { + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "from_csv", + x@jc, schema, options) + column(jc) + }) + +#' @details +#' \code{from_utc_timestamp}: Given a timestamp like '2017-07-14 02:40:00.0', interprets it as a #' time in UTC, and renders that time as a timestamp in the given time zone.
For example, 'GMT+1' diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 27c1b312d645..9d8ebc8f6936 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -984,6 +984,10 @@ setGeneric("format_string", function(format, x, ...) { standardGeneric("format_s #' @name NULL setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("from_csv", function(x, schema, ...) { standardGeneric("from_csv") }) + #' @rdname column_datetime_functions #' @name NULL setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 40d8f8084f2f..004a848fff87 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1659,6 +1659,11 @@ test_that("column functions", { expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) + # Test from_csv() + df <- as.DataFrame(list(list("col" = "1"))) + c <- collect(select(df, alias(from_csv(df$col, "a INT"), "csv"))) + expect_equal(c[[1]][[1]]$a, 1) + # Test to_json(), from_json() df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") j <- collect(select(df, alias(to_json(df$people), "json"))) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6da5237d18de..60cc9d70a6fe 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2348,6 +2348,27 @@ def schema_of_json(col): return Column(jc) +@ignore_unicode_prefix +@since(2.5) +def schema_of_csv(col, options={}): + """ + Parses a column containing a CSV string and infers its schema in DDL format. + + :param col: string column in CSV format + :param options: options to control parsing. accepts the same options as the CSV datasource + + >>> from pyspark.sql.types import * + >>> data = [(1, '1|a')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(schema_of_csv(df.value, {'sep':'|'}).alias("csv")).collect() + [Row(csv=u'struct<_c0:int,_c1:string>')] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.schema_of_csv(_to_java_column(col), options) + return Column(jc) + + @since(1.5) def size(col): """ @@ -2637,6 +2658,29 @@ def sequence(start, stop, step=None): _to_java_column(start), _to_java_column(stop), _to_java_column(step))) +@ignore_unicode_prefix +@since(3.0) +def from_csv(col, schema, options={}): + """ + Parses a column containing a CSV string into a :class:`StructType` + with the specified schema. Returns `null`, in the case of an unparseable string. + + :param col: string column in CSV format + :param schema: a string with schema in DDL format to use when parsing the CSV column. + :param options: options to control parsing. 
accepts the same options as the CSV datasource + + >>> from pyspark.sql.types import * + >>> data = [(1, '1')] + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_csv(df.value, "a INT").alias("csv")).collect() + [Row(csv=Row(a=1))] + """ + + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options) + return Column(jc) + + # ---------------------------- User Defined Function ---------------------------------- class PandasUDFType(object): diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 224c70ce24d6..af83dbff2383 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -103,6 +103,12 @@ commons-codec commons-codec + + com.univocity + univocity-parsers + 2.7.3 + jar + target/scala-${scala.binary.version}/classes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 8b69a4703696..b728fa33c7f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -529,7 +529,11 @@ object FunctionRegistry { castAlias("date", DateType), castAlias("timestamp", TimestampType), castAlias("binary", BinaryType), - castAlias("string", StringType) + castAlias("string", StringType), + + // csv + expression[CsvToStructs]("from_csv"), + expression[SchemaOfCsv]("schema_of_csv") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index a585cbed2551..6dd52e0d0549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -15,18 +15,18 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.math.BigDecimal -import scala.util.control.Exception._ +import scala.util.control.Exception.allCatch import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ -private[csv] object CSVInferSchema { +object CSVInferSchema { /** * Similar to the JSON schema inference @@ -43,13 +43,7 @@ private[csv] object CSVInferSchema { val rootTypes: Array[DataType] = tokenRDD.aggregate(startType)(inferRowType(options), mergeRowTypes) - header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other - } - StructField(thisHeader, dType, nullable = true) - } + toStructFields(rootTypes, header, options) } else { // By default fields are assumed to be StringType header.map(fieldName => StructField(fieldName, StringType, nullable = true)) @@ -58,7 +52,20 @@ private[csv] object CSVInferSchema { StructType(fields) } - private def inferRowType(options: CSVOptions) + def toStructFields( + fieldTypes: Array[DataType], + header: Array[String], + options: CSVOptions): Array[StructField] = { + header.zip(fieldTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) + } + } + + def inferRowType(options: CSVOptions) (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = { var i = 0 while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 492a21be6df3..1f39b20bb2f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.nio.charset.StandardCharsets import java.util.{Locale, TimeZone} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala new file mode 100644 index 000000000000..adbe7c402d51 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVUtils.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.csv + +object CSVUtils { + /** + * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). + * This is currently being used in CSV reading path and CSV schema inference. + */ + def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + iter.filter { line => + line.trim.nonEmpty && !line.startsWith(options.comment.toString) + } + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. + */ + @throws[IllegalArgumentException] + def toChar(str: String): Char = { + if (str.charAt(0) == '\\') { + str.charAt(1) + match { + case 't' => '\t' + case 'r' => '\r' + case 'b' => '\b' + case 'f' => '\f' + case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options + case '\'' => '\'' + case 'u' if str == """\u0000""" => '\u0000' + case _ => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + } + } else if (str.length == 1) { + str.charAt(0) + } else { + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala similarity index 98% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index 9088d43905e2..4e9508cadc7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import java.io.InputStream import java.math.BigDecimal @@ -28,8 +28,7 @@ import com.univocity.parsers.csv.CsvParser import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} -import org.apache.spark.sql.execution.datasources.FailureSafeParser +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -264,7 +263,7 @@ class UnivocityParser( } } -private[csv] object UnivocityParser { +private[sql] object UnivocityParser { /** * Parses a stream that contains CSV strings and turns it into an iterator of tokens. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala new file mode 100644 index 000000000000..e5708894f22b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData +import org.apache.spark.sql.types.{MapType, StringType, StructType} + +object ExprUtils { + + def evalSchemaExpr(exp: Expression): StructType = exp match { + case Literal(s, StringType) => StructType.fromDDL(s.toString) + case e => throw new AnalysisException( + s"Schema should be specified in DDL format as a string literal instead of ${e.sql}") + } + + def convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala new file mode 100644 index 000000000000..0e6c2cfeaa99 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import com.fasterxml.jackson.core.JsonFactory +import com.univocity.parsers.csv.CsvParser + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +/** + * Converts a CSV input string to a [[StructType]] with the specified schema. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(csvStr, schema[, options]) - Returns a struct value with the given `csvStr` and `schema`.", + examples = """ + Examples: + > SELECT _FUNC_('1, 0.8', 'a INT, b DOUBLE'); + {"a":1, "b":0.8} + """, + since = "3.0.0") +// scalastyle:on line.size.limit +case class CsvToStructs( + schema: StructType, + options: Map[String, String], + child: Expression, + timeZoneId: Option[String] = None) + extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { + + override def nullable: Boolean = true + + // The CSV input data might be missing certain fields. We force the nullability + // of the user-provided schema to avoid data corruptions. + val nullableSchema = schema.asNullable + + // Used in `FunctionRegistry` + def this(child: Expression, schema: Expression, options: Map[String, String]) = + this( + schema = ExprUtils.evalSchemaExpr(schema), + options = options, + child = child, + timeZoneId = None) + + def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String]) + + def this(child: Expression, schema: Expression, options: Expression) = + this( + schema = ExprUtils.evalSchemaExpr(schema), + options = ExprUtils.convertToMapData(options), + child = child, + timeZoneId = None) + + // This converts parsed rows to the desired output by the given schema. + @transient + lazy val converter = (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null + + @transient lazy val parser = { + val parsedOptions = new CSVOptions(options, true, timeZoneId.get) + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(s"from_csv() doesn't support the ${mode.name} mode. 
" + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}.") + } + val actualSchema = + StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)) + val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions) + new FailureSafeParser[String]( + input => Seq(rawParser.parse(input)), + mode, + nullableSchema, + parsedOptions.columnNameOfCorruptRecord, + parsedOptions.multiLine) + } + + override def dataType: DataType = nullableSchema + + override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = + copy(timeZoneId = Option(timeZoneId)) + + override def nullSafeEval(input: Any): Any = { + val csv = input.asInstanceOf[UTF8String].toString + if (csv.trim.isEmpty) return null + try { + converter(parser.parse(csv)) + } catch { + case _: BadRecordException => null + } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil +} + +@ExpressionDescription( + usage = "_FUNC_(csv[, options]) - Returns schema in the DDL format of CSV string.", + examples = """ + Examples: + > SELECT _FUNC_('1,abc'); + struct<_c0:int,_c1:string> + """, + since = "2.5.0") +case class SchemaOfCsv( + child: Expression, + options: Map[String, String]) + extends UnaryExpression with String2StringExpression with CodegenFallback { + + def this(child: Expression) = this(child, Map.empty[String, String]) + + def this(child: Expression, options: Expression) = this( + child = child, + options = ExprUtils.convertToMapData(options)) + + override def convert(v: UTF8String): UTF8String = { + val parsedOptions = new CSVOptions(options, true, "UTC") + val parser = new CsvParser(parsedOptions.asParserSettings) + val row = parser.parseLine(v.toString) + + if (row != null) { + val header = row.zipWithIndex.map { case (_, index) => s"_c$index" } + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val fieldTypes = CSVInferSchema.inferRowType(parsedOptions)(startType, row) + val st = StructType(CSVInferSchema.toStructFields(fieldTypes, header, parsedOptions)) + UTF8String.fromString(st.catalogString) + } else { + null + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bd9090a07471..2d8e16f5c8de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -539,7 +539,7 @@ case class JsonToStructs( def this(child: Expression, schema: Expression, options: Expression) = this( schema = JsonExprUtils.evalSchemaExpr(schema), - options = JsonExprUtils.convertToMapData(options), + options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -650,7 +650,7 @@ case class StructsToJson( def this(child: Expression) = this(Map.empty, child, None) def this(child: Expression, options: Expression) = this( - options = JsonExprUtils.convertToMapData(options), + options = ExprUtils.convertToMapData(options), child = child, timeZoneId = None) @@ -771,18 +771,4 @@ object JsonExprUtils { "Schema should be specified in DDL format as a string literal" + s" or output of the schema_of_json function instead of ${e.sql}") } - - def convertToMapData(exp: Expression): Map[String, String] = exp match { - case m: CreateMap - if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => - val 
arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] - ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => - key.toString -> value.toString - } - case m: CreateMap => - throw new AnalysisException( - s"A type of keys and values in map() must be string, but got ${m.dataType.catalogString}") - case _ => - throw new AnalysisException("Must use a map() function for options") - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala similarity index 95% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala index 90e81661bae7..fecfff5789a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FailureSafeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala @@ -15,13 +15,11 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources +package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala index 221e44ce2cff..3217df9aed33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVUtilsSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.csv +package org.apache.spark.sql.catalyst.csv import org.apache.spark.SparkFunSuite diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala new file mode 100644 index 000000000000..aee699178a75 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import java.util.Calendar + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.plans.PlanTestBase +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with PlanTestBase { + val badCsv = "\u0000\u0000\u0000A\u0001AAA" + + val gmtId = Option(DateTimeUtils.TimeZoneGMT.getID) + + test("from_csv") { + val csvData = "1" + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData), gmtId), + InternalRow(1) + ) + } + + test("from_csv - invalid data") { + val csvData = "---" + val schema = StructType(StructField("a", DoubleType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(csvData), gmtId), + InternalRow(null)) + + // Default mode is Permissive + checkEvaluation(CsvToStructs(schema, Map.empty, Literal(csvData), gmtId), InternalRow(null)) + } + + test("from_csv null input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId), + null + ) + } + + test("from_csv bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(badCsv), gmtId), + InternalRow(null)) + } + + test("from_csv with timestamp") { + val schema = StructType(StructField("t", TimestampType) :: Nil) + + val csvData1 = "2016-01-01T00:00:00.123Z" + var c = Calendar.getInstance(DateTimeUtils.TimeZoneGMT) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 123) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData1), gmtId), + InternalRow(c.getTimeInMillis * 1000L) + ) + // The result doesn't change because the CSV string includes timezone string ("Z" here), + // which means the string represents the timestamp string in the timezone regardless of + // the timeZoneId parameter. 
+ checkEvaluation( + CsvToStructs(schema, Map.empty, Literal(csvData1), Option("PST")), + InternalRow(c.getTimeInMillis * 1000L) + ) + + val csvData2 = "2016-01-01T00:00:00" + for (tz <- DateTimeTestUtils.ALL_TIMEZONES) { + c = Calendar.getInstance(tz) + c.set(2016, 0, 1, 0, 0, 0) + c.set(Calendar.MILLISECOND, 0) + checkEvaluation( + CsvToStructs( + schema, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"), + Literal(csvData2), + Option(tz.getID)), + InternalRow(c.getTimeInMillis * 1000L) + ) + checkEvaluation( + CsvToStructs( + schema, + Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss", + DateTimeUtils.TIMEZONE_OPTION -> tz.getID), + Literal(csvData2), + gmtId), + InternalRow(c.getTimeInMillis * 1000L) + ) + } + } + + test("from_csv empty input column") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + CsvToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId), + null + ) + } + + test("from_csv missing fields") { + val input = """1,,"foo"""" + val csvSchema = new StructType() + .add("a", LongType, nullable = false) + .add("b", StringType, nullable = false) + .add("c", StringType, nullable = false) + val output = InternalRow(1L, null, UTF8String.fromString("foo")) + val expr = CsvToStructs(csvSchema, Map.empty, Literal.create(input, StringType), gmtId) + checkEvaluation(expr, output) + val schema = expr.dataType + val schemaToCompare = csvSchema.asNullable + assert(schemaToCompare == schema) + } + + test("infer schema of CSV strings") { + checkEvaluation(new SchemaOfCsv(Literal.create("1,abc")), "struct<_c0:int,_c1:string>") + } + + test("infer schema of CSV strings by using options") { + checkEvaluation( + new SchemaOfCsv(Literal.create("1|abc"), + CreateMap(Seq(Literal.create("delimiter"), Literal.create("|")))), + "struct<_c0:int,_c1:string>") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index fe69f252d43e..efb164a54001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -29,9 +29,11 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index e840ff168250..09e3c5461a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -34,6 +34,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, 
Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions, UnivocityParser} +import org.apache.spark.sql.catalyst.csv.CSVUtils.filterCommentAndEmpty import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat import org.apache.spark.sql.types.StructType @@ -248,7 +250,7 @@ object TextInputCSVDataSource extends CSVDataSource { val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions) val tokenRDD = sampled.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val filteredLines = filterCommentAndEmpty(iter, parsedOptions) val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, maybeFirstLine.get, parsedOptions) val parser = new CsvParser(parsedOptions.asParserSettings) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 9aad0bd55e73..7310196d0d17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala index 7ce65fa89b02..786e5123899f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.rdd.RDD import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ object CSVUtils { /** @@ -40,16 +40,6 @@ object CSVUtils { } } - /** - * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). - * This is currently being used in CSV reading path and CSV schema inference. - */ - def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = { - iter.filter { line => - line.trim.nonEmpty && !line.startsWith(options.comment.toString) - } - } - /** * Skip the given first line so that only data can remain in a dataset. * This is similar with `dropHeaderLine` below and currently being used in CSV schema inference. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala index 4082a0df8ba7..37d9d9abc868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityGenerator.scala @@ -22,6 +22,7 @@ import java.io.Writer import com.univocity.parsers.csv.CsvWriter import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 76f58371ae26..c7608e2e881f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD} import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7a1ca5..11d526a29b69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3811,6 +3811,63 @@ object functions { @scala.annotation.varargs def map_concat(cols: Column*): Column = withExpr { MapConcat(cols.map(_.expr)) } + /** + * Parses a column containing a CSV string into a `StructType` with the specified schema. + * Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing CSV data. + * @param schema the schema to use when parsing the CSV string + * @param options options to control how the CSV is parsed. accepts the same options as the + * CSV data source. + * + * @group collection_funcs + * @since 3.0.0 + */ + def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + CsvToStructs(schema, options, e.expr) + } + + /** + * (Java-specific) Parses a column containing a CSV string into a `StructType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing CSV data. + * @param schema the schema to use when parsing the CSV string + * @param options options to control how the CSV is parsed. accepts the same options as the + * CSV data source. + * + * @group collection_funcs + * @since 3.0.0 + */ + def from_csv(e: Column, schema: String, options: java.util.Map[String, String]): Column = { + withExpr(new CsvToStructs(e.expr, lit(schema).expr, options.asScala.toMap)) + } + + /** + * Parses a column containing a CSV string and infers its schema. + * + * @param e a string column containing CSV data.
+ * + * @group collection_funcs + * @since 2.5.0 + */ + def schema_of_csv(e: Column): Column = withExpr(new SchemaOfCsv(e.expr)) + + /** + * Parses a column containing a CSV string and infers its schema using options. + * + * @param e a string column containing CSV data. + * @param options options to control how the CSV is parsed. accepts the same options as the + * CSV data source. See [[DataFrameReader#csv]]. + * @return a column with string literal containing schema in DDL format. + * + * @group collection_funcs + * @since 2.5.0 + */ + def schema_of_csv(e: Column, options: java.util.Map[String, String]): Column = { + withExpr(SchemaOfCsv(e.expr, options.asScala.toMap)) + } + // scalastyle:off line.size.limit // scalastyle:off parameter.number diff --git a/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql new file mode 100644 index 000000000000..b5577396ee33 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/csv-functions.sql @@ -0,0 +1,16 @@ +-- from_csv +describe function from_csv; +describe function extended from_csv; +select from_csv('1', 'a INT'); +select from_csv('1, 3.14', 'a INT, f FLOAT'); +select from_csv('26/08/2015', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors are handled +select from_csv('1', 1); +select from_csv('1', 'a InvalidType'); +select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')); +select from_csv('1', 'a INT', map('mode', 1)); +select from_csv(); + +-- infer schema of csv literal +select schema_of_csv('1,abc'); +select schema_of_csv('1|abc', map('delimiter', '|')); diff --git a/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out new file mode 100644 index 000000000000..c427ef22cd3f --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/csv-functions.sql.out @@ -0,0 +1,120 @@ +-- Automatically generated by SQLQueryTestSuite -- Number of queries: 12 + + +-- !query 0 +describe function from_csv -- !query 0 schema struct<function_desc:string> -- !query 0 output Class: org.apache.spark.sql.catalyst.expressions.CsvToStructs Function: from_csv Usage: from_csv(csvStr, schema[, options]) - Returns a struct value with the given `csvStr` and `schema`. + + +-- !query 1 +describe function extended from_csv -- !query 1 schema struct<function_desc:string> -- !query 1 output Class: org.apache.spark.sql.catalyst.expressions.CsvToStructs Extended Usage: Examples: > SELECT from_csv('1, 0.8', 'a INT, b DOUBLE'); {"a":1, "b":0.8} Since: 3.0.0 Function: from_csv Usage: from_csv(csvStr, schema[, options]) - Returns a struct value with the given `csvStr` and `schema`.
+ + +-- !query 2 +select from_csv('1', 'a INT') +-- !query 2 schema +struct> +-- !query 2 output +{"a":1} + + +-- !query 3 +select from_csv('1, 3.14', 'a INT, f FLOAT') +-- !query 3 schema +struct> +-- !query 3 output +{"a":1,"f":3.14} + + +-- !query 4 +select from_csv('26/08/2015', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) +-- !query 4 schema +struct> +-- !query 4 output +{"time":2015-08-26 00:00:00.0} + + +-- !query 5 +select from_csv('1', 1) +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Schema should be specified in DDL format as a string literal instead of 1;; line 1 pos 7 + + +-- !query 6 +select from_csv('1', 'a InvalidType') +-- !query 6 schema +struct<> +-- !query 6 output +org.apache.spark.sql.AnalysisException + +DataType invalidtype is not supported.(line 1, pos 2) + +== SQL == +a InvalidType +--^^^ +; line 1 pos 7 + + +-- !query 7 +select from_csv('1', 'a INT', named_struct('mode', 'PERMISSIVE')) +-- !query 7 schema +struct<> +-- !query 7 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 8 +select from_csv('1', 'a INT', map('mode', 1)) +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +A type of keys and values in map() must be string, but got map;; line 1 pos 7 + + +-- !query 9 +select from_csv() +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function from_csv. Expected: one of 2 and 3; Found: 0; line 1 pos 7 + + +-- !query 10 +select schema_of_csv('1,abc') +-- !query 10 schema +struct +-- !query 10 output +struct<_c0:int,_c1:string> + + +-- !query 11 +select schema_of_csv('1|abc', map('delimiter', '|')) +-- !query 11 schema +struct +-- !query 11 output +struct<_c0:int,_c1:string> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala new file mode 100644 index 000000000000..e89718124300 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import collection.JavaConverters._ + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ + +class CsvFunctionsSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + val noOptions = Map[String, String]() + + test("from_csv") { + val df = Seq("1").toDS() + val schema = new StructType().add("a", IntegerType) + + checkAnswer( + df.select(from_csv($"value", schema, noOptions)), + Row(Row(1)) :: Nil) + } + + test("from_csv with option") { + val df = Seq("26/08/2015 18:00").toDS() + val schema = new StructType().add("time", TimestampType) + val options = Map("timestampFormat" -> "dd/MM/yyyy HH:mm") + + checkAnswer( + df.select(from_csv($"value", schema, options)), + Row(Row(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))) + } + + test("from_csv missing columns") { + val df = Seq("1").toDS() + val schema = new StructType() + .add("a", IntegerType) + .add("b", IntegerType) + + checkAnswer( + df.select(from_csv($"value", schema, noOptions)), + Row(Row(1, null)) :: Nil) + } + + test("from_csv invalid CSV") { + val df = Seq("???").toDS() + val schema = new StructType().add("a", IntegerType) + + checkAnswer( + df.select(from_csv($"value", schema, noOptions)), + Row(Row(null)) :: Nil) + } + + test("Support from_csv in SQL") { + val df1 = Seq("1").toDS() + checkAnswer( + df1.selectExpr("from_csv(value, 'a INT')"), + Row(Row(1)) :: Nil) + } + + test("infers schemas using options") { + val df = spark.range(1) + .select(schema_of_csv(lit("0.1 1"), Map("sep" -> " ").asJava)) + checkAnswer(df, Seq(Row("struct<_c0:double,_c1:int>"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 57e36e082653..e8fccfc98fc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions} import org.apache.spark.sql.types._ class CSVInferSchemaSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index 458edb253fb3..6f231142949d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.util.Locale import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.csv.{CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String
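For reference, a minimal usage sketch of the API this patch introduces, written against the Scala signatures added in functions.scala above. The SparkSession setup, object name, sample data, and column names are illustrative assumptions, not part of the change itself.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{from_csv, lit, schema_of_csv}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

object FromCsvSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; any existing SparkSession works the same way.
    val spark = SparkSession.builder().master("local[*]").appName("from_csv sketch").getOrCreate()
    import spark.implicits._

    val df = Seq("1,foo").toDF("value")

    // Scala overload: parse the CSV string column against an explicit StructType.
    val schema = new StructType().add("a", IntegerType).add("b", StringType)
    df.select(from_csv($"value", schema, Map.empty[String, String])).show(false)

    // Infer a DDL-formatted schema string from a CSV value, e.g. struct<_c0:int,_c1:string>.
    df.select(schema_of_csv(lit("1,foo"))).show(false)

    // The same functions are registered for SQL via FunctionRegistry.
    spark.sql("SELECT from_csv('1,foo', 'a INT, b STRING')").show(false)

    spark.stop()
  }
}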