4 changes: 2 additions & 2 deletions python/pyspark/ml/__init__.py
@@ -19,7 +19,7 @@
DataFrame-based machine learning APIs to let users quickly assemble and configure practical
machine learning pipelines.
"""
-from pyspark.ml.base import Estimator, Model, Transformer
+from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
from pyspark.ml.pipeline import Pipeline, PipelineModel

__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
56 changes: 56 additions & 0 deletions python/pyspark/ml/base.py
@@ -17,9 +17,14 @@

from abc import ABCMeta, abstractmethod

import copy

from pyspark import since
from pyspark.ml.param import Params
from pyspark.ml.param.shared import *
from pyspark.ml.common import inherit_doc
from pyspark.sql.functions import udf
from pyspark.sql.types import StructField, StructType, DoubleType


@inherit_doc
@@ -116,3 +121,54 @@ class Model(Transformer):
    """

    __metaclass__ = ABCMeta


@inherit_doc
class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
    """
    Abstract class for transformers that take one input column, apply a transformation
    to it, and output the result as a new column.

    .. versionadded:: 2.3.0
    """

    @abstractmethod
    def createTransformFunc(self):
        """
        Creates the transform function using the given param map. The input param map already
        takes into account the embedded param map, so the param values should be determined
        solely by the input param map.
        """
        raise NotImplementedError()

    @abstractmethod
    def outputDataType(self):
        """
        Returns the data type of the output column.
        """
        raise NotImplementedError()

    @abstractmethod
    def validateInputType(self, inputType):
        """
        Validates the input type, throwing an exception if it is invalid.
        """
        raise NotImplementedError()

    def transformSchema(self, schema):
        inputType = schema[self.getInputCol()].dataType
        self.validateInputType(inputType)
        if self.getOutputCol() in schema.names:
            raise ValueError("Output column %s already exists." % self.getOutputCol())
        outputFields = copy.copy(schema.fields)
        outputFields.append(StructField(self.getOutputCol(),
                                        self.outputDataType(),
                                        nullable=False))
        return StructType(outputFields)

    def _transform(self, dataset):
        self.transformSchema(dataset.schema)
        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
        transformedDataset = dataset.withColumn(self.getOutputCol(),
                                                transformUDF(dataset[self.getInputCol()]))
        return transformedDataset
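
A hypothetical sketch (not part of this patch) of how a user-defined subclass would build on the new base class: only the three abstract methods need to be supplied, while transformSchema and _transform above are inherited. StringLengthTransformer and its column names are assumptions for illustration.

from pyspark.ml import UnaryTransformer
from pyspark.sql.types import IntegerType, StringType


class StringLengthTransformer(UnaryTransformer):
    """Illustrative only: maps a string column to its length."""

    def createTransformFunc(self):
        # Plain Python callable; UnaryTransformer._transform wraps it in a udf.
        return lambda s: len(s)

    def outputDataType(self):
        return IntegerType()

    def validateInputType(self, inputType):
        if inputType != StringType():
            raise TypeError("Bad input type: %s. Requires String." % inputType)


# Usage sketch, assuming a DataFrame `df` with a string column "name":
# df = StringLengthTransformer().setInputCol("name").setOutputCol("name_len").transform(df)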
62 changes: 61 additions & 1 deletion python/pyspark/ml/tests.py
@@ -45,7 +45,7 @@
import inspect

from pyspark import keyword_only, SparkContext
-from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer
from pyspark.ml.classification import *
from pyspark.ml.clustering import *
from pyspark.ml.common import _java2py, _py2java
@@ -66,6 +66,7 @@
from pyspark.serializers import PickleSerializer
from pyspark.sql import DataFrame, Row, SparkSession
from pyspark.sql.functions import rand
from pyspark.sql.types import DoubleType, IntegerType
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase

@@ -121,6 +122,36 @@ def _transform(self, dataset):
        return dataset


class MockUnaryTransformer(UnaryTransformer):

    shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
                  "data in a DataFrame",
                  typeConverter=TypeConverters.toFloat)

    def __init__(self, shiftVal=1):
        super(MockUnaryTransformer, self).__init__()
        self._setDefault(shift=1)
        self._set(shift=shiftVal)

    def getShift(self):
        return self.getOrDefault(self.shift)

    def setShift(self, shift):
        self._set(shift=shift)

    def createTransformFunc(self):
        shiftVal = self.getShift()
        return lambda x: x + shiftVal

    def outputDataType(self):
        return DoubleType()

    def validateInputType(self, inputType):
        if inputType != DoubleType():
            raise TypeError("Bad input type: {}. ".format(inputType) +
                            "Requires Double.")


class MockEstimator(Estimator, HasFake):

    def __init__(self):
@@ -1957,6 +1988,35 @@ def test_chisquaretest(self):
        self.assertTrue(all(field in fieldNames for field in expectedFields))


class UnaryTransformerTests(SparkSessionTestCase):

    def test_unary_transformer_validate_input_type(self):
        shiftVal = 3
        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
            .setInputCol("input").setOutputCol("output")

        # should not raise any errors
        transformer.validateInputType(DoubleType())

        with self.assertRaises(TypeError):
            # passing the wrong input type should raise an error
            transformer.validateInputType(IntegerType())

    def test_unary_transformer_transform(self):
        shiftVal = 3
        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
            .setInputCol("input").setOutputCol("output")

        df = self.spark.range(0, 10).toDF('input')
        df = df.withColumn("input", df.input.cast(dataType="double"))

        transformed_df = transformer.transform(df)
        results = transformed_df.select("input", "output").collect()

        for res in results:
            self.assertEqual(res.input + shiftVal, res.output)

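
# Illustrative sketch (not part of this patch) of what test_unary_transformer_transform
# exercises, assuming an existing SparkSession named `spark`:
#
#     df = spark.range(0, 10).toDF("input")
#     df = df.withColumn("input", df.input.cast("double"))
#     t = MockUnaryTransformer(shiftVal=3).setInputCol("input").setOutputCol("output")
#     t.transform(df).show()
#
# Each value in the "output" column equals the corresponding "input" value plus 3.0.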

if __name__ == "__main__":
    from pyspark.ml.tests import *
    if xmlrunner: