Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml](
private[ml] val randUnitVectors: Array[Vector])
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {

/** @group setParam */
@Since("2.4.0")
override def setInputCol(value: String): this.type = super.set(inputCol, value)

/** @group setParam */
@Since("2.4.0")
override def setOutputCol(value: String): this.type = super.set(outputCol, value)

@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
key: Vector => {
Expand Down
6 changes: 6 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
extends Model[T] with LSHParams with MLWritable {
self: T =>

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* The hash function of LSH, mapping an input feature vector to multiple hash vectors.
* @return The mapping of LSH function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ class MinHashLSHModel private[ml](
private[ml] val randCoefficients: Array[(Int, Int)])
extends LSHModel[MinHashLSHModel] {

/** @group setParam */
@Since("2.4.0")
override def setInputCol(value: String): this.type = super.set(inputCol, value)

/** @group setParam */
@Since("2.4.0")
override def setOutputCol(value: String): this.type = super.set(outputCol, value)

@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
elems: Vector => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest
ParamsSuite.checkParams(model)
}

test("setters") {
val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0)))
.setInputCol("testkeys")
.setOutputCol("testvalues")
assert(model.getInputCol === "testkeys")
assert(model.getOutputCol === "testvalues")
}

test("BucketedRandomProjectionLSH: default params") {
val brp = new BucketedRandomProjectionLSH
assert(brp.getNumHashTables === 1.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
ParamsSuite.checkParams(model)
}

test("setters") {
val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
.setInputCol("testkeys")
.setOutputCol("testvalues")
assert(model.getInputCol === "testkeys")
assert(model.getOutputCol === "testvalues")
}

test("MinHashLSH: default params") {
val rp = new MinHashLSH
assert(rp.getNumHashTables === 1.0)
Expand Down