Switch the LSH hashDistance API from Seq[Vector] to Array[Vector], and replace
the collection-based hashDistance implementations in
BucketedRandomProjectionLSHModel and MinHashLSHModel with while loops that stop
accumulating once the running distance reaches the current minimum and return
0.0 as soon as one pair of hash vectors matches exactly.

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 951da518c8c8e..475fc5b7f8ccf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -95,9 +95,25 @@ class BucketedRandomProjectionLSHModel private[ml](
   }
 
   @Since("2.1.0")
-  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
+  override protected[ml] def hashDistance(x: Array[Vector], y: Array[Vector]): Double = {
     // Since it's generated by hashing, it will be a pair of dense vectors.
-    x.zip(y).map(vectorPair => Vectors.sqdist(vectorPair._1, vectorPair._2)).min
+    var distance = Double.MaxValue
+    var i = 0
+    while (i < x.length) {
+      val vx = x(i).toArray
+      val vy = y(i).toArray
+      var j = 0
+      var d = 0.0
+      while (j < vx.length && d < distance) {
+        val diff = vx(j) - vy(j)
+        d += diff * diff
+        j += 1
+      }
+      if (d == 0) return 0.0
+      if (d < distance) distance = d
+      i += 1
+    }
+    distance
   }
 
   @Since("2.1.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 9d647f3e514c5..c3304047fce90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -94,7 +94,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
    * @param y Another hash vector.
    * @return The distance between hash vectors x and y.
    */
-  protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double
+  protected[ml] def hashDistance(x: Array[Vector], y: Array[Vector]): Double
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -116,25 +116,25 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
     // Get Hash Value of the key
     val keyHash = hashFunction(key)
-    val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
+    val modelDataset = if (!dataset.columns.contains($(outputCol))) {
       transform(dataset)
     } else {
       dataset.toDF()
     }
 
     val modelSubset = if (singleProbe) {
-      def sameBucket(x: Seq[Vector], y: Seq[Vector]): Boolean = {
-        x.zip(y).exists(tuple => tuple._1 == tuple._2)
+      def sameBucket(x: Array[Vector], y: Array[Vector]): Boolean = {
+        x.iterator.zip(y.iterator).exists(tuple => tuple._1 == tuple._2)
       }
 
       // In the origin dataset, find the hash value that hash the same bucket with the key
-      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) => sameBucket(x, keyHash))
+      val sameBucketWithKeyUDF = udf((x: Array[Vector]) => sameBucket(x, keyHash))
 
       modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
     } else {
       // In the origin dataset, find the hash value that is closest to the key
       // Limit the use of hashDist since it's controversial
-      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash))
+      val hashDistUDF = udf((x: Array[Vector]) => hashDistance(x, keyHash))
       val hashDistCol = hashDistUDF(col($(outputCol)))
       val modelDatasetWithDist = modelDataset.withColumn(distCol, hashDistCol)
 
@@ -223,7 +223,7 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
       inputName: String,
       explodeCols: Seq[String]): Dataset[_] = {
     require(explodeCols.size == 2, "explodeCols must be two strings.")
-    val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
+    val modelDataset = if (!dataset.columns.contains($(outputCol))) {
       transform(dataset)
     } else {
       dataset.toDF()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 12cae13174379..d189edcb4e558 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -106,12 +106,25 @@ class MinHashLSHModel private[ml](
   }
 
   @Since("2.1.0")
-  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
+  override protected[ml] def hashDistance(x: Array[Vector], y: Array[Vector]): Double = {
     // Since it's generated by hashing, it will be a pair of dense vectors.
     // TODO: This hashDistance function requires more discussion in SPARK-18454
-    x.iterator.zip(y.iterator).map(vectorPair =>
-      vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
-    ).min
+    var distance = Int.MaxValue
+    var i = 0
+    while (i < x.length) {
+      val vx = x(i).toArray
+      val vy = y(i).toArray
+      var j = 0
+      var d = 0
+      while (j < vx.length && d < distance) {
+        if (vx(j) != vy(j)) d += 1
+        j += 1
+      }
+      if (d == 0) return 0.0
+      if (d < distance) distance = d
+      i += 1
+    }
+    distance
   }
 
   @Since("2.1.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index 55dade28920ed..2815adb75adf3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -78,7 +78,7 @@ private[ml] object LSHTest {
     // Perform a cross join and label each pair of same_bucket and distance
     val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
     val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y))
-    val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0)
+    val sameBucket = udf((x: Array[Vector], y: Array[Vector]) => model.hashDistance(x, y) == 0.0)
     val result = pairs
       .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
       .withColumn("distance", distUDF(col(s"a.$inputCol"), col(s"b.$inputCol")))