Skip to content

Commit 8353000

Browse files
huaxingaosrowen
authored andcommitted
[SPARK-29746][ML] Implement validateInputType in Normalizer/ElementwiseProduct/PolynomialExpansion
### What changes were proposed in this pull request? This PR implements ```validateInput``` in ```ElementwiseProduct```, ```Normalizer``` and ```PolynomialExpansion```. ### Why are the changes needed? ```UnaryTransformer``` has abstract method ```validateInputType``` and call it in ```transformSchema```, but this method is not implemented in ```ElementwiseProduct```, ```Normalizer``` and ```PolynomialExpansion```. ### Does this PR introduce any user-facing change? No ### How was this patch tested? Existing tests Closes #26388 from huaxingao/spark-29746. Authored-by: Huaxin Gao <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 90df858 commit 8353000

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@ class ElementwiseProduct @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
7575
}
7676
}
7777

78+
override protected def validateInputType(inputType: DataType): Unit = {
79+
require(inputType.isInstanceOf[VectorUDT],
80+
s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.")
81+
}
82+
7883
override protected def outputDataType: DataType = new VectorUDT()
7984
}
8085

mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class Normalizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
5959
vector => normalizer.transform(OldVectors.fromML(vector)).asML
6060
}
6161

62+
override protected def validateInputType(inputType: DataType): Unit = {
63+
require(inputType.isInstanceOf[VectorUDT],
64+
s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.")
65+
}
66+
6267
override protected def outputDataType: DataType = new VectorUDT()
6368
}
6469

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ class PolynomialExpansion @Since("1.4.0") (@Since("1.4.0") override val uid: Str
6868
PolynomialExpansion.expand(v, $(degree))
6969
}
7070

71+
override protected def validateInputType(inputType: DataType): Unit = {
72+
require(inputType.isInstanceOf[VectorUDT],
73+
s"Input type must be ${(new VectorUDT).catalogString} but got ${inputType.catalogString}.")
74+
}
75+
7176
override protected def outputDataType: DataType = new VectorUDT()
7277

7378
@Since("1.4.1")

0 commit comments

Comments
 (0)