Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@

package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS._
import org.apache.spark.sql.DataFrame
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
@Experimental
class Word2VecModel private[ml] (
override val uid: String,
wordVectors: feature.Word2VecModel)
@transient wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {


/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
* "vector" being the DenseVector that the word is mapped to.
Expand Down Expand Up @@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
val vectors = wordVectors.getVectors
.mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
.map(identity) // mapValues doesn't return a serializable map (SI-7005)
val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
val d = $(vectorSize)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
} else {
val cum = Vectors.zeros($(vectorSize))
val model = bWordVectors.value.getVectors
for (word <- sentence) {
if (model.contains(word)) {
axpy(1.0, bWordVectors.value.transform(word), cum)
} else {
// skip words that do not belong to the model
val sum = Vectors.zeros(d)
sentence.foreach { word =>
bVectors.value.get(word).foreach { v =>
BLAS.axpy(1.0, v, sum)
}
}
scal(1.0 / sentence.size, cum)
cum
BLAS.scal(1.0 / sentence.size, sum)
sum
}
}
dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))
Expand Down