@@ -28,6 +28,7 @@ import org.apache.spark.annotation.Experimental
2828import org .apache .spark .mllib .linalg .{Vector , Vectors }
2929import org .apache .spark .mllib .rdd .RDDFunctions ._
3030import org .apache .spark .rdd ._
31+ import org .apache .spark .util .Utils
3132import org .apache .spark .util .random .XORShiftRandom
3233
3334/**
@@ -58,29 +59,63 @@ private case class VocabWord(
5859 * Efficient Estimation of Word Representations in Vector Space
5960 * and
6061 * Distributed Representations of Words and Phrases and their Compositionality.
61- * @param size vector dimension
62- * @param startingAlpha initial learning rate
63- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
64- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
6562 */
6663@ Experimental
67- class Word2Vec (
68- val size : Int ,
69- val startingAlpha : Double ,
70- val parallelism : Int ,
71- val numIterations : Int ) extends Serializable with Logging {
64+ class Word2Vec extends Serializable with Logging {
65+
66+ private var vectorSize = 100
67+ private var startingAlpha = 0.025
68+ private var numPartitions = 1
69+ private var numIterations = 1
70+ private var seed = Utils .random.nextLong()
71+
72+ /**
73+ * Sets vector size (default: 100).
74+ */
75+ def setVectorSize (vectorSize : Int ): this .type = {
76+ this .vectorSize = vectorSize
77+ this
78+ }
79+
80+ /**
81+ * Sets initial learning rate (default: 0.025).
82+ */
83+ def setLearningRate (learningRate : Double ): this .type = {
84+ this .startingAlpha = learningRate
85+ this
86+ }
7287
7388 /**
74- * Word2Vec with a single thread .
89+ * Sets number of partitions (default: 1). Use a small number for accuracy .
7590 */
76- def this (size : Int , startingAlpha : Int ) = this (size, startingAlpha, 1 , 1 )
91+ def setNumPartitions (numPartitions : Int ): this .type = {
92+ require(numPartitions > 0 , s " numPartitions must be greater than 0 but got $numPartitions" )
93+ this .numPartitions = numPartitions
94+ this
95+ }
96+
97+ /**
98+ * Sets number of iterations (default: 1), which should be smaller than or equal to number of
99+ * partitions.
100+ */
101+ def setNumIterations (numIterations : Int ): this .type = {
102+ this .numIterations = numIterations
103+ this
104+ }
105+
106+ /**
107+ * Sets random seed (default: a random long integer).
108+ */
109+ def setSeed (seed : Long ): this .type = {
110+ this .seed = seed
111+ this
112+ }
77113
78114 private val EXP_TABLE_SIZE = 1000
79115 private val MAX_EXP = 6
80116 private val MAX_CODE_LENGTH = 40
81117 private val MAX_SENTENCE_LENGTH = 1000
82- private val layer1Size = size
83- private val modelPartitionNum = 100
118+ private val layer1Size = vectorSize
84119
85120 /** context words from [-window, window] */
86121 private val window = 5
@@ -245,8 +280,7 @@ class Word2Vec(
245280 }
246281 }
247282
248- val newSentences = sentences.repartition(parallelism).cache()
249- val seed = 5875483L
283+ val newSentences = sentences.repartition(numPartitions).cache()
250284 val initRandom = new XORShiftRandom (seed)
251285 var syn0Global =
252286 Array .fill[Float ](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f ) / layer1Size)
@@ -263,7 +297,7 @@ class Word2Vec(
263297 lwc = wordCount
264298 // TODO: discount by iteration?
265299 alpha =
266- startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1 ))
300+ startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1 ))
267301 if (alpha < startingAlpha * 0.0001 ) alpha = startingAlpha * 0.0001
268302 logInfo(" wordCount = " + wordCount + " , alpha = " + alpha)
269303 }
@@ -404,23 +438,3 @@ class Word2VecModel private[mllib] (
404438 .toArray
405439 }
406440}
407-
408- object Word2Vec {
409- /**
410- * Train Word2Vec model
411- * @param input RDD of words
412- * @param size vector dimension
413- * @param startingAlpha initial learning rate
414- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
415- * @param numIterations number of iterations, should be smaller than or equal to parallelism
416- * @return Word2Vec model
417- */
418- def train [S <: Iterable [String ]](
419- input : RDD [S ],
420- size : Int ,
421- startingAlpha : Double ,
422- parallelism : Int = 1 ,
423- numIterations: Int = 1 ): Word2VecModel = {
424- new Word2Vec (size,startingAlpha, parallelism, numIterations).fit[S ](input)
425- }
426- }
0 commit comments