From a9f894df3671bb8df2f342de1820dab3185598f3 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Tue, 10 Nov 2015 11:40:22 +0800
Subject: [PATCH 1/3] ml.feature.Word2Vec.transform() function very slow, we shouldn't read broadcast every sentence

---
 mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 9edab3af913c..5822e56c1c2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -198,12 +198,12 @@ class Word2VecModel private[ml] (
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    val model = bWordVectors.value.getVectors
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
         Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
       } else {
         val cum = Vectors.zeros($(vectorSize))
-        val model = bWordVectors.value.getVectors
         for (word <- sentence) {
           if (model.contains(word)) {
             axpy(1.0, bWordVectors.value.transform(word), cum)

From a2b2835d793ebde1fd41b48a9c19021855545252 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 11 Nov 2015 09:43:06 +0800
Subject: [PATCH 2/3] modify wordVectors to transient.

---
 mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 5822e56c1c2e..12a929f08414 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -148,7 +148,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
 @Experimental
 class Word2VecModel private[ml] (
     override val uid: String,
-    wordVectors: feature.Word2VecModel)
+    @transient wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {

From fec64602ed98f85708bf738d09c72b30b1010702 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 10 Nov 2015 19:15:54 -0800
Subject: [PATCH 3/3] improve implementation to avoid mapping per record

---
 .../apache/spark/ml/feature/Word2Vec.scala | 32 +++++++++----------
 1 file changed, 15 insertions(+), 17 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 12a929f08414..5c64cb09d594 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,18 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.annotation.Experimental
 import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.types._
 
 /**
@@ -151,7 +149,6 @@ class Word2VecModel private[ml] (
     @transient wordVectors: feature.Word2VecModel)
   extends Model[Word2VecModel] with Word2VecBase {
 
-
   /**
    * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
    * and the vector the DenseVector that it is mapped to.
   */
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
-    val model = bWordVectors.value.getVectors
+    val vectors = wordVectors.getVectors
+      .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+      .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+    val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+    val d = $(vectorSize)
     val word2Vec = udf { sentence: Seq[String] =>
       if (sentence.size == 0) {
-        Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+        Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
       } else {
-        val cum = Vectors.zeros($(vectorSize))
-        for (word <- sentence) {
-          if (model.contains(word)) {
-            axpy(1.0, bWordVectors.value.transform(word), cum)
-          } else {
-            // pass words which not belong to model
+        val sum = Vectors.zeros(d)
+        sentence.foreach { word =>
+          bVectors.value.get(word).foreach { v =>
+            BLAS.axpy(1.0, v, sum)
           }
         }
-        scal(1.0 / sentence.size, cum)
-        cum
+        BLAS.scal(1.0 / sentence.size, sum)
+        sum
       }
     }
     dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
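
For reference, the per-sentence averaging that PATCH 3/3 moves into the UDF can be sketched standalone in plain Scala, without Spark: the pre-materialized map is broadcast once and looked up per word, words missing from the vocabulary are skipped, and the sum is divided by the full sentence length; hoisting $(vectorSize) into the local d presumably also keeps the UDF closure from referencing (and serializing) the model itself. The sketch below mirrors only that arithmetic; Word2VecAverageSketch, averageWords, vectorMap, and vectorSize are illustrative names, not part of the patch or of Spark's API, and the plain array loops stand in for the BLAS.axpy / BLAS.scal calls.

// Minimal, standalone sketch of the averaging performed inside the UDF in PATCH 3/3.
object Word2VecAverageSketch {

  // Average the vectors of the words found in `vectorMap`; unknown words are skipped,
  // but the divisor is still the full sentence length, matching scal(1.0 / sentence.size, sum).
  def averageWords(
      sentence: Seq[String],
      vectorMap: Map[String, Array[Double]],
      vectorSize: Int): Array[Double] = {
    val sum = new Array[Double](vectorSize)
    if (sentence.isEmpty) {
      sum // all zeros, mirroring the empty sparse vector returned for an empty sentence
    } else {
      sentence.foreach { word =>
        vectorMap.get(word).foreach { v =>
          var i = 0
          while (i < vectorSize) { sum(i) += v(i); i += 1 } // axpy(1.0, v, sum)
        }
      }
      var i = 0
      while (i < vectorSize) { sum(i) /= sentence.size; i += 1 } // scal(1.0 / sentence.size, sum)
      sum
    }
  }

  def main(args: Array[String]): Unit = {
    val vectors = Map("spark" -> Array(1.0, 0.0), "ml" -> Array(0.0, 1.0))
    // "unknown" is skipped, but the divisor is 3, so the result is (1/3, 1/3).
    println(averageWords(Seq("spark", "ml", "unknown"), vectors, 2).mkString(", "))
  }
}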