Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,25 @@ class BucketedRandomProjectionLSHModel private[ml](
}

@Since("2.1.0")
override protected[ml] def hashDistance(x: Array[Vector], y: Array[Vector]): Double = {
  // Since it's generated by hashing, each element is expected to be a dense vector.
  // The hash distance is the minimum squared Euclidean distance over the
  // corresponding pairs of hash vectors.
  var distance = Double.MaxValue
  var i = 0
  while (i < x.length) {
    val vx = x(i).toArray
    val vy = y(i).toArray
    var j = 0
    var d = 0.0
    // Accumulate the squared distance for this pair, bailing out early once it
    // can no longer improve on the best (smallest) distance found so far.
    while (j < vx.length && d < distance) {
      val diff = vx(j) - vy(j)
      d += diff * diff
      j += 1
    }
    // Zero is the global minimum; no need to inspect the remaining pairs.
    if (d == 0) return 0.0
    if (d < distance) distance = d
    i += 1
  }
  distance
}

@Since("2.1.0")
Expand Down
14 changes: 7 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
* @param y Another hash vector.
* @return The distance between hash vectors x and y.
*/
protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double
protected[ml] def hashDistance(x: Array[Vector], y: Array[Vector]): Double

override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
Expand All @@ -116,25 +116,25 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
// Get Hash Value of the key
val keyHash = hashFunction(key)
val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
val modelDataset = if (!dataset.columns.contains($(outputCol))) {
transform(dataset)
} else {
dataset.toDF()
}

val modelSubset = if (singleProbe) {
def sameBucket(x: Seq[Vector], y: Seq[Vector]): Boolean = {
x.zip(y).exists(tuple => tuple._1 == tuple._2)
def sameBucket(x: Array[Vector], y: Array[Vector]): Boolean = {
x.iterator.zip(y.iterator).exists(tuple => tuple._1 == tuple._2)
}

// In the original dataset, find the hash values that fall in the same bucket as the key
val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))
val sameBucketWithKeyUDF = udf((x: Array[Vector]) => sameBucket(x, keyHash))

modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
} else {
// In the original dataset, find the hash value that is closest to the key
// Limit the use of hashDist since it's controversial
val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
val hashDistUDF = udf((x: Array[Vector]) => hashDistance(x, keyHash))
val hashDistCol = hashDistUDF(col($(outputCol)))
val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)

Expand Down Expand Up @@ -223,7 +223,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
inputName: String,
explodeCols: Seq[String]): Dataset[_] = {
require(explodeCols.size == 2, "explodeCols must be two strings.")
val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
val modelDataset = if (!dataset.columns.contains($(outputCol))) {
transform(dataset)
} else {
dataset.toDF()
Expand Down
21 changes: 17 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,25 @@ class MinHashLSHModel private[ml](
}

@Since("2.1.0")
override protected[ml] def hashDistance(x: Array[Vector], y: Array[Vector]): Double = {
  // Since it's generated by hashing, each element is expected to be a dense vector.
  // The hash distance is the minimum, over corresponding pairs of hash vectors,
  // of the number of mismatching hash slots (a Hamming-style distance).
  // TODO: This hashDistance function requires more discussion in SPARK-18454
  var distance = Int.MaxValue
  var i = 0
  while (i < x.length) {
    val vx = x(i).toArray
    val vy = y(i).toArray
    var j = 0
    var d = 0
    // Count mismatching slots for this pair, bailing out early once the count
    // can no longer beat the best (smallest) count found so far.
    while (j < vx.length && d < distance) {
      if (vx(j) != vy(j)) d += 1
      j += 1
    }
    // Zero mismatches is the global minimum; no need to inspect remaining pairs.
    if (d == 0) return 0.0
    if (d < distance) distance = d
    i += 1
  }
  // Int result widens implicitly to Double.
  distance
}

@Since("2.1.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ private[ml] object LSHTest {
// Perform a cross join and label each pair of same_bucket and distance
val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0)
val sameBucket = udf((x: Array[Vector], y: Array[Vector]) => model.hashDistance(x, y) == 0.0)
val result = pairs
.withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
.withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))
Expand Down