diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
new file mode 100644
index 0000000000000..1fe3cfc74c76d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * A feature transformer that adds size information to the metadata of a vector column.
+ * VectorAssembler needs size information for its input columns and cannot be used on streaming
+ * DataFrames without this metadata.
+ *
+ */
+@Experimental
+@Since("2.3.0")
+class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+  extends Transformer with HasInputCol with HasHandleInvalid with DefaultParamsWritable {
+
+  @Since("2.3.0")
+  def this() = this(Identifiable.randomUID("vectSizeHint"))
+
+  /**
+   * The size of Vectors in `inputCol`.
+   * @group param
+   */
+  @Since("2.3.0")
+  val size: IntParam = new IntParam(
+    this,
+    "size",
+    "Size of vectors in column.",
+    {s: Int => s >= 0})
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getSize: Int = getOrDefault(size)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setSize(value: Int): this.type = set(size, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /**
+   * Param for how to handle invalid entries. Invalid vectors include nulls and vectors with the
+   * wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an
+   * error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.
+   *
+   * Note: Users should take care when setting this param to `optimistic`. The use of the
+   * `optimistic` option will prevent the transformer from validating the sizes of vectors in
+   * `inputCol`. A mismatch between the metadata of a column and its contents could result in
+   * unexpected behaviour or errors when using that column.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  override val handleInvalid: Param[String] = new Param[String](
+    this,
+    "handleInvalid",
+    "How to handle invalid vectors in inputCol. Invalid vectors include nulls and vectors with " +
+      "the wrong size. The options are `skip` (filter out rows with invalid vectors), `error` " +
+      "(throw an error) and `optimistic` (do not check the vector size, and keep all rows). " +
+      "`error` by default.",
+    ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID)
+
+  @Since("2.3.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    val localInputCol = getInputCol
+    val localSize = getSize
+    val localHandleInvalid = getHandleInvalid
+
+    val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
+    val newGroup = validateSchemaAndSize(dataset.schema, group)
+    if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
+      dataset.toDF()
+    } else {
+      val newCol: Column = localHandleInvalid match {
+        case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
+        case VectorSizeHint.ERROR_INVALID =>
+          val checkVectorSizeUDF = udf { vector: Vector =>
+            if (vector == null) {
+              throw new SparkException(s"Got null vector in VectorSizeHint, set `handleInvalid` " +
+                s"to 'skip' to filter invalid rows.")
+            }
+            if (vector.size != localSize) {
+              throw new SparkException(s"VectorSizeHint expected a vector of size $localSize " +
+                s"but got ${vector.size}.")
+            }
+            vector
+          }.asNondeterministic()
+          checkVectorSizeUDF(col(localInputCol))
+        case VectorSizeHint.SKIP_INVALID =>
+          val checkVectorSizeUDF = udf { vector: Vector =>
+            if (vector != null && vector.size == localSize) {
+              vector
+            } else {
+              null
+            }
+          }
+          checkVectorSizeUDF(col(localInputCol))
+      }
+
+      val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
+      if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
+        res.na.drop(Array(localInputCol))
+      } else {
+        res
+      }
+    }
+  }
+
+  /**
+   * Checks that the schema can be updated with the new size and returns a new attribute group
+   * with the updated size.
+   */
+  private def validateSchemaAndSize(schema: StructType, group: AttributeGroup): AttributeGroup = {
+    // This will throw a NoSuchElementException if params are not set.
+    val localSize = getSize
+    val localInputCol = getInputCol
+
+    val inputColType = schema(getInputCol).dataType
+    require(
+      inputColType.isInstanceOf[VectorUDT],
+      s"Input column $getInputCol must be of Vector type, but got $inputColType"
+    )
+    group.size match {
+      case `localSize` => group
+      case -1 => new AttributeGroup(localInputCol, localSize)
+      case _ =>
+        val msg = s"Trying to set size of vectors in `$localInputCol` to $localSize but size " +
+          s"already set to ${group.size}."
+        throw new IllegalArgumentException(msg)
+    }
+  }
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    val fieldIndex = schema.fieldIndex(getInputCol)
+    val fields = schema.fields.clone()
+    val inputField = fields(fieldIndex)
+    val group = AttributeGroup.fromStructField(inputField)
+    val newGroup = validateSchemaAndSize(schema, group)
+    fields(fieldIndex) = inputField.copy(metadata = newGroup.toMetadata())
+    StructType(fields)
+  }
+
+  @Since("2.3.0")
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+}
+
+/** :: Experimental :: */
+@Experimental
+@Since("2.3.0")
+object VectorSizeHint extends DefaultParamsReadable[VectorSizeHint] {
+
+  private[feature] val OPTIMISTIC_INVALID = "optimistic"
+  private[feature] val ERROR_INVALID = "error"
+  private[feature] val SKIP_INVALID = "skip"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(OPTIMISTIC_INVALID, ERROR_INVALID, SKIP_INVALID)
+
+  @Since("2.3.0")
+  override def load(path: String): VectorSizeHint = super.load(path)
+}
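A minimal usage sketch of the new transformer (not part of the patch; the DataFrame `df` and its columns "userFeatures" and "clicks" are hypothetical). Declaring the vector size up front is what lets VectorAssembler resolve its output schema even when `df` is a streaming DataFrame:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// Declare that "userFeatures" always holds vectors of size 3; rows with null
// or wrongly sized vectors are dropped because handleInvalid is set to "skip".
val sizeHint = new VectorSizeHint()
  .setInputCol("userFeatures")
  .setSize(3)
  .setHandleInvalid("skip")

// VectorAssembler can now read the size from the column metadata instead of
// inspecting the data, so the pipeline also works on streaming DataFrames.
val assembler = new VectorAssembler()
  .setInputCols(Array("userFeatures", "clicks"))
  .setOutputCol("features")

val pipeline = new Pipeline().setStages(Array(sizeHint, assembler))
val assembled = pipeline.fit(df).transform(df)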
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
new file mode 100644
index 0000000000000..f6c9a76599fae
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+class VectorSizeHintSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  test("Test Param Validators") {
+    intercept[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue"))
+    intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
+  }
+
+  test("Required params must be set before transform.") {
+    val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue")
+
+    val noSizeTransformer = new VectorSizeHint().setInputCol("vector")
+    intercept[NoSuchElementException] (noSizeTransformer.transform(data))
+    intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema))
+
+    val noInputColTransformer = new VectorSizeHint().setSize(2)
+    intercept[NoSuchElementException] (noInputColTransformer.transform(data))
+    intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema))
+  }
+
+  test("Adding size to column of vectors.") {
+
+    val size = 3
+    val vectorColName = "vector"
+    val denseVector = Vectors.dense(1, 2, 3)
+    val sparseVector = Vectors.sparse(size, Array(), Array())
+
+    val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
+    val dataFrame = data.toDF(vectorColName)
+    assert(
+      AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
+      s"This test requires that column '$vectorColName' not have size metadata.")
+
+    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+      val transformer = new VectorSizeHint()
+        .setInputCol(vectorColName)
+        .setSize(size)
+        .setHandleInvalid(handleInvalid)
+      val withSize = transformer.transform(dataFrame)
+      assert(
+        AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
+        "Transformer did not add expected size data.")
+      val numRows = withSize.collect().length
+      assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+    }
+  }
+
+  test("Size hint preserves attributes.") {
+
+    val size = 3
+    val vectorColName = "vector"
+    val data = Seq((1, 2, 3), (2, 3, 3))
+    val dataFrame = data.toDF("x", "y", "z")
+
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z"))
+      .setOutputCol(vectorColName)
+    val dataFrameWithMetadata = assembler.transform(dataFrame)
+    val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName))
+
+    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+      val transformer = new VectorSizeHint()
+        .setInputCol(vectorColName)
+        .setSize(size)
+        .setHandleInvalid(handleInvalid)
+      val withSize = transformer.transform(dataFrameWithMetadata)
+
+      val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
+      assert(newGroup.size === size, "Column has incorrect size metadata.")
+      assert(
+        newGroup.attributes.get === group.attributes.get,
+        "VectorSizeHint did not preserve attributes.")
+      withSize.collect()
+    }
+  }
+
+  test("Size mismatch between current and target size raises an error.") {
+    val size = 4
+    val vectorColName = "vector"
+    val data = Seq((1, 2, 3), (2, 3, 3))
+    val dataFrame = data.toDF("x", "y", "z")
+
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z"))
+      .setOutputCol(vectorColName)
+    val dataFrameWithMetadata = assembler.transform(dataFrame)
+
+    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+      val transformer = new VectorSizeHint()
+        .setInputCol(vectorColName)
+        .setSize(size)
+        .setHandleInvalid(handleInvalid)
+      intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata))
+    }
+  }
+
+  test("Handle invalid does the right thing.") {
+
+    val vector = Vectors.dense(1, 2, 3)
+    val short = Vectors.dense(2)
+    val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector")
+    val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")
+
+    val sizeHint = new VectorSizeHint()
+      .setInputCol("vector")
+      .setHandleInvalid("error")
+      .setSize(3)
+
+    intercept[SparkException](sizeHint.transform(dataWithNull).collect())
+    intercept[SparkException](sizeHint.transform(dataWithShort).collect())
+
+    sizeHint.setHandleInvalid("skip")
+    assert(sizeHint.transform(dataWithNull).count() === 1)
+    assert(sizeHint.transform(dataWithShort).count() === 1)
+
+    sizeHint.setHandleInvalid("optimistic")
+    assert(sizeHint.transform(dataWithNull).count() === 2)
+    assert(sizeHint.transform(dataWithShort).count() === 2)
+  }
+
+  test("read/write") {
+    val sizeHint = new VectorSizeHint()
+      .setInputCol("myInputCol")
+      .setSize(11)
+      .setHandleInvalid("skip")
+    testDefaultReadWrite(sizeHint)
+  }
+}
+
+class VectorSizeHintStreamingSuite extends StreamTest {
+
+  import testImplicits._
+
+  test("Test assemble vectors with size hint in streaming.") {
+    val a = Vectors.dense(0, 1, 2)
+    val b = Vectors.sparse(4, Array(0, 3), Array(3, 6))
+
+    val stream = MemoryStream[(Vector, Vector)]
+    val streamingDF = stream.toDS.toDF("a", "b")
+    val sizeHintA = new VectorSizeHint()
+      .setSize(3)
+      .setInputCol("a")
+    val sizeHintB = new VectorSizeHint()
+      .setSize(4)
+      .setInputCol("b")
+    val vectorAssembler = new VectorAssembler()
+      .setInputCols(Array("a", "b"))
+      .setOutputCol("assembled")
+    val pipeline = new Pipeline().setStages(Array(sizeHintA, sizeHintB, vectorAssembler))
+    val output = pipeline.fit(streamingDF).transform(streamingDF).select("assembled")
+
+    val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6)
+
+    testStream(output)(
+      AddData(stream, (a, b), (a, b)),
+      CheckAnswer(Tuple1(expected), Tuple1(expected))
+    )
+  }
+}
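The hint written by the transformer travels as ML attribute metadata on the vector column; a rough sketch of how a consumer reads it back (not part of the patch; the transformed DataFrame `withSize` and its column "vector" are hypothetical, produced as in the tests above):

import org.apache.spark.ml.attribute.AttributeGroup

// VectorAssembler, like the assertions in the suite above, reads the declared
// size from the column metadata rather than from the data; -1 means unknown.
val group = AttributeGroup.fromStructField(withSize.schema("vector"))
assert(group.size == 3)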
""" from pyspark.serializers import ArrowSerializer, _create_batch - from pyspark.sql.types import from_arrow_schema, to_arrow_type, \ - _old_pandas_exception_message, TimestampType - from pyspark.sql.utils import _require_minimum_pyarrow_version - try: - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType + from pyspark.sql.utils import require_minimum_pandas_version, \ + require_minimum_pyarrow_version + + require_minimum_pandas_version() + require_minimum_pyarrow_version() - _require_minimum_pyarrow_version() + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # Determine arrow types to coerce data when creating batches if isinstance(schema, StructType): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6fdfda1cc831b..b977160af566d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -53,7 +53,8 @@ try: import pandas try: - import pandas.api + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() _have_pandas = True except: _have_old_pandas = True @@ -2600,7 +2601,7 @@ def test_to_pandas(self): @unittest.skipIf(not _have_old_pandas, "Old Pandas not installed") def test_to_pandas_old(self): with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self._to_pandas() @unittest.skipIf(not _have_pandas, "Pandas not installed") @@ -2643,7 +2644,7 @@ def test_create_dataframe_from_old_pandas(self): pdf = pd.DataFrame({"ts": [datetime(2017, 10, 31, 1, 1, 1)], "d": [pd.Timestamp.now().date()]}) with QuietTest(self.sc): - with self.assertRaisesRegexp(ImportError, 'Pandas \(.*\) must be installed'): + with self.assertRaisesRegexp(ImportError, 'Pandas >= .* must be installed'): self.spark.createDataFrame(pdf) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 46d9a417414b5..063264a89379c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1678,13 +1678,6 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _old_pandas_exception_message(e): - """ Create an error message for importing old Pandas. - """ - msg = "note: Pandas (>=0.19.2) must be installed and available on calling Python process" - return "%s\n%s" % (_exception_message(e), msg) - - def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1693,10 +1686,10 @@ def _check_dataframe_localize_timestamps(pdf, timezone): :param timezone: the timezone to convert. if None then use local timezone :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive """ - try: - from pandas.api.types import is_datetime64tz_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype tz = timezone or 'tzlocal()' for column, series in pdf.iteritems(): # TODO: handle nested timestamps, such as ArrayType(TimestampType())? @@ -1714,10 +1707,10 @@ def _check_series_convert_timestamps_internal(s, timezone): :param timezone: the timezone to convert. 
if None then use local timezone :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone """ - try: - from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype # TODO: handle nested timestamps, such as ArrayType(TimestampType())? if is_datetime64_dtype(s.dtype): tz = timezone or 'tzlocal()' @@ -1737,11 +1730,11 @@ def _check_series_convert_timestamps_localize(s, from_timezone, to_timezone): :param to_timezone: the timezone to convert to. if None then use local timezone :return pandas.Series where if it is a timestamp, has been converted to tz-naive """ - try: - import pandas as pd - from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype - except ImportError as e: - raise ImportError(_old_pandas_exception_message(e)) + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + import pandas as pd + from pandas.api.types import is_datetime64tz_dtype, is_datetime64_dtype from_tz = from_timezone or 'tzlocal()' to_tz = to_timezone or 'tzlocal()' # TODO: handle nested timestamps, such as ArrayType(TimestampType())? diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 50c87ba1ac882..123138117fdc3 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -37,9 +37,9 @@ def _create_udf(f, returnType, evalType): if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \ evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF: import inspect - from pyspark.sql.utils import _require_minimum_pyarrow_version + from pyspark.sql.utils import require_minimum_pyarrow_version - _require_minimum_pyarrow_version() + require_minimum_pyarrow_version() argspec = inspect.getargspec(f) if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF and len(argspec.args) == 0 and \ diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index cc7dabb64b3ec..fb7d42a35d8f4 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -112,7 +112,16 @@ def toJArray(gateway, jtype, arr): return jarr -def _require_minimum_pyarrow_version(): +def require_minimum_pandas_version(): + """ Raise ImportError if minimum version of Pandas is not installed + """ + from distutils.version import LooseVersion + import pandas + if LooseVersion(pandas.__version__) < LooseVersion('0.19.2'): + raise ImportError("Pandas >= 0.19.2 must be installed on calling Python process") + + +def require_minimum_pyarrow_version(): """ Raise ImportError if minimum version of pyarrow is not installed """ from distutils.version import LooseVersion