-
Notifications
You must be signed in to change notification settings - Fork 29k
[MLlib] [SPARK-2510]Word2Vec: Distributed Representation of Words #1719
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
8d6befe
0aafb1b
e4a04d3
57dc50d
2e92b59
720b5a3
6bcc8be
7efbb6f
1a8fb41
e93e726
384c771
c14da41
26a948d
e248441
2ba9483
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,375 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * Add a comment to this line | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.mllib.feature | ||
|
|
||
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark._
import org.apache.spark.HashPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd._
|
|
||
/**
 * Entry in vocabulary.
 *
 * Mirrors the `vocab_word` struct of the original C word2vec implementation;
 * `point`/`code`/`codeLen` are filled in by `createBinaryTree()` and encode the
 * word's path in the Huffman tree used for hierarchical softmax.
 */
private case class VocabWord(
  var word: String,        // the token itself
  var cn: Int,             // count of occurrences in the training corpus
  var point: Array[Int],   // Huffman-tree inner-node indices on the path to this word
  var code: Array[Int],    // Huffman binary code (0/1 per tree level)
  var codeLen:Int          // number of valid entries in `point`/`code`
)
|
|
||
/**
 * Word2Vec computes distributed vector representations of words, training a
 * skip-gram model with hierarchical softmax over an input corpus. The
 * implementation follows the original C word2vec tool
 * (https://code.google.com/p/word2vec/) and the papers by Mikolov et al.,
 * "Efficient Estimation of Word Representations in Vector Space" and
 * "Distributed Representations of Words and Phrases and their Compositionality".
 *
 * NOTE: variable and method names (syn0, syn1, neu1e, alpha, ...) deliberately
 * match the original C implementation to ease cross-checking.
 *
 * This is an experimental feature.
 *
 * @param size dimension of the word vectors
 * @param startingAlpha initial learning rate; word2vec is sensitive to this —
 *                      too large a value may produce meaningless vectors
 *                      (0.025 is the C implementation's default for skip-gram)
 * @param window context window: words in [-window, window] around each target
 * @param minCount minimum token frequency for inclusion in the vocabulary
 */
class Word2Vec(
    val size: Int,
    val startingAlpha: Double,
    val window: Int,
    val minCount: Int)
  extends Serializable with Logging {

  // Constants carried over from the C implementation.
  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40
  private val MAX_SENTENCE_LENGTH = 1000
  private val layer1Size = size
  private val modelPartitionNum = 100

  private var trainWordsCount = 0
  private var vocabSize = 0
  private var vocab: Array[VocabWord] = null
  private var vocabHash = mutable.HashMap.empty[String, Int]
  private var alpha = startingAlpha

  /**
   * Builds the vocabulary from the dataset: counts tokens, drops those with
   * fewer than `minCount` occurrences, and sorts by descending frequency
   * (required by the Huffman-tree construction below).
   */
  private def learnVocab(dataset: RDD[String]) {
    vocab = dataset.flatMap(line => line.split(" "))
      .map(w => (w, 1))
      .reduceByKey(_ + _)
      .map(x => VocabWord(
        x._1,
        x._2,
        new Array[Int](MAX_CODE_LENGTH),
        new Array[Int](MAX_CODE_LENGTH),
        0))
      .filter(_.cn >= minCount)
      .collect()
      .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a
      trainWordsCount += vocab(a).cn
      a += 1
    }
    logInfo("trainWordsCount = " + trainWordsCount)
  }

  /**
   * Precomputes sigmoid(x) for x in (-MAX_EXP, MAX_EXP), sampled at
   * EXP_TABLE_SIZE points, so the inner training loop avoids calling exp().
   */
  private def createExpTable(): Array[Double] = {
    val expTable = new Array[Double](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = tmp / (tmp + 1)
      i += 1
    }
    expTable
  }

  /**
   * Builds the Huffman tree over the frequency-sorted vocabulary and assigns
   * each word its binary code and inner-node path, exactly as the C
   * implementation's CreateBinaryTree() does. Frequent words get short codes.
   */
  private def createBinaryTree() {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
    val code = new Array[Int](MAX_CODE_LENGTH)
    val point = new Array[Int](MAX_CODE_LENGTH)
    // Leaves: real word counts; inner-node slots start at "infinity" (1e9).
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn
      a += 1
    }
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt
      a += 1
    }
    // pos1 walks down the sorted leaves, pos2 walks up the created inner nodes;
    // at each step the two smallest available nodes are merged.
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

    var min1i = 0
    var min2i = 0

    a = 0
    while (a < vocabSize - 1) {
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min1i = pos1
          pos1 -= 1
        } else {
          min1i = pos2
          pos2 += 1
        }
      } else {
        min1i = pos2
        pos2 += 1
      }
      if (pos1 >= 0) {
        if (count(pos1) < count(pos2)) {
          min2i = pos1
          pos1 -= 1
        } else {
          min2i = pos2
          pos2 += 1
        }
      } else {
        min2i = pos2
        pos2 += 1
      }
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word by walking leaf -> root,
    // then reversing the collected bits/path.
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      while (b != vocabSize * 2 - 2) {
        code(i) = binary(b)
        point(i) = b
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i
      vocab(a).point(0) = vocabSize - 2
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)
        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
      }
      a += 1
    }
  }

  /**
   * Computes the vector representation of each word in the vocabulary.
   *
   * @param dataset an RDD of strings; each record is treated as
   *                whitespace-delimited text and chopped into sentences of at
   *                most MAX_SENTENCE_LENGTH known words
   * @return a Word2VecModel holding (word, vector) pairs
   */
  def fit(dataset: RDD[String]): Word2VecModel = {

    learnVocab(dataset)

    createBinaryTree()

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    // NOTE(review): `V` is never read below — the training closure captures
    // `vocab` directly (serialized with the task) instead of using this
    // broadcast. Kept for now; should either be used or removed.
    val V = sc.broadcast(vocab)
    val VHash = sc.broadcast(vocabHash)

    // Map tokens to vocabulary indices and group them into fixed-size
    // "sentences"; out-of-vocabulary tokens are silently dropped.
    val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
      iter => new Iterator[Array[Int]] {
        def hasNext = iter.hasNext

        def next = {
          var sentence = new ArrayBuffer[Int]
          var sentenceLength = 0
          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
            val word = VHash.value.get(iter.next)
            word match {
              case Some(w) =>
                sentence += w
                sentenceLength += 1
              case None =>
            }
          }
          sentence.toArray
        }
      }
    }

    val newSentences = sentences.repartition(1).cache()
    // Initial weights in (-0.5/layer1Size, 0.5/layer1Size).
    // TODO: fix the random seed to make training reproducible.
    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
    val (aggSyn0, _, _, _) =
      // TODO: broadcast temp instead of serializing it directly
      // or initialize the model in each executor
      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
        seqOp = (c, v) => (c, v) match {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
            var wc = wordCount
            // Decay the learning rate every ~10000 words, as in the C code.
            // NOTE(review): `alpha` is a driver-side var captured in this
            // closure — updates made inside tasks do not propagate back to the
            // driver or across executors; confirm intended semantics.
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
              alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
              if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
            }
            wc += sentence.size
            var pos = 0
            while (pos < sentence.size) {
              val word = sentence(pos)
              // TODO: fix random seed
              val b = Random.nextInt(window)
              // Train Skip-gram: predict context words in [-window+b, window-b]
              var a = b
              while (a < window * 2 + 1 - b) {
                if (a != window) {
                  val c = pos - window + a
                  if (c >= 0 && c < sentence.size) {
                    val lastWord = sentence(c)
                    val l1 = lastWord * layer1Size
                    val neu1e = new Array[Double](layer1Size)
                    // Hierarchical softmax over the word's Huffman path
                    var d = 0
                    while (d < vocab(word).codeLen) {
                      val l2 = vocab(word).point(d) * layer1Size
                      // Propagate hidden -> output
                      var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                      if (f > -MAX_EXP && f < MAX_EXP) {
                        val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                        f = expTable.value(ind)
                        val g = (1 - vocab(word).code(d) - f) * alpha
                        blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
                        blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                      }
                      d += 1
                    }
                    blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                  }
                }
                a += 1
              }
              pos += 1
            }
            (syn0, syn1, lwc, wc)
        },
        combOp = (c1, c2) => (c1, c2) match {
          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
            val n = syn0_1.length
            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
            // BUG FIX: the original returned (syn0_1, syn0_2, ...), placing the
            // *unsummed* syn0_2 in the syn1 accumulator slot; the combined
            // result must be (syn0_1, syn1_1, ...).
            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
        })

    // Slice the flat aggregated syn0 array into one vector per word.
    val wordMap = new Array[(String, Array[Double])](vocabSize)
    var i = 0
    while (i < vocabSize) {
      val word = vocab(i).word
      val vector = new Array[Double](layer1Size)
      Array.copy(aggSyn0, i * layer1Size, vector, 0, layer1Size)
      wordMap(i) = (word, vector)
      i += 1
    }
    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
      .partitionBy(new HashPartitioner(modelPartitionNum))
    new Word2VecModel(modelRDD)
  }
}
|
|
||
/**
 * Word2Vec model: an RDD of (word, vector) pairs produced by [[Word2Vec]],
 * supporting word -> vector lookup and nearest-neighbor (synonym) queries.
 */
class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializable {

  // The underlying (word, vector) RDD; hash-partitioned by word for lookup().
  val model = _model

  /**
   * Cosine similarity of two vectors (despite the name "distance"):
   * dot(v1, v2) / (||v1|| * ||v2||), or 0.0 if either vector has zero norm.
   */
  private def distance(v1: Array[Double], v2: Array[Double]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
    val norm1 = blas.dnrm2(n, v1, 1)
    val norm2 = blas.dnrm2(n, v2, 1)
    if (norm1 == 0 || norm2 == 0) return 0.0
    blas.ddot(n, v1, 1, v2,1) / norm1 / norm2
  }

  /**
   * Transforms a word to its vector representation.
   * @param word a word
   * @return vector representation of word; an empty array if the word is not
   *         in the vocabulary
   */
  def transform(word: String): Array[Double] = {
    // lookup() is an RDD action executed on the driver.
    val result = model.lookup(word)
    if (result.isEmpty) Array[Double]()
    else result(0)
  }

  /**
   * Transforms an RDD of words to their vector representations.
   * @param dataset an RDD of words
   * @return RDD of vector representation
   */
  def transform(dataset: RDD[String]): RDD[Array[Double]] = {
    // NOTE(review): this calls transform(word) — and hence model.lookup(), an
    // RDD action — inside a transformation over `dataset`. Spark does not
    // support launching RDD operations from within executors; this likely
    // fails at runtime and probably needs a join against `model` instead.
    // Confirm before relying on it.
    dataset.map(word => transform(word))
  }

  /**
   * Find synonyms of a word.
   * @param word a word
   * @param num number of synonyms to find
   * @return array of (word, similarity); empty if the word is out of vocabulary
   */
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    if (vector.isEmpty) Array[(String, Double)]()
    else findSynonyms(vector,num)
  }

  /**
   * Find synonyms of the vector representation of a word.
   * @param vector vector representation of a word
   * @param num number of synonyms to find
   * @return array of (word, similarity), most similar first
   */
  def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    // Rank all words by cosine similarity and keep the top num, dropping the
    // best match. NOTE(review): drop(1) assumes the query word itself ranks
    // first, which holds when called from findSynonyms(word, num) but silently
    // discards the closest match for an arbitrary caller-supplied vector —
    // confirm this is intended.
    val topK = model.map(
      {case(w, vec) => (distance(vector, vec), w)})
      .sortByKey(ascending = false)
      .take(num + 1)
      .map({case (dist, w) => (w, dist)}).drop(1)

    topK
  }
}
|
|
||
/** Top-level entry point for training a [[Word2VecModel]]. */
object Word2Vec extends Serializable with Logging {

  /**
   * Train Word2Vec model
   * @param input RDD of strings; each record is split on spaces by the trainer
   *              (presumably one sentence or line of text per record —
   *              TODO confirm and document the expected record granularity)
   * @param size vector dimension
   * @param startingAlpha initial learning rate
   * @param window context words from [-window, window]
   * @param minCount minimum frequency to consider a vocabulary word
   * @return Word2Vec model
   */
  def train(
      input: RDD[String],
      size: Int,
      startingAlpha: Double,
      window: Int,
      minCount: Int): Word2VecModel = {
    new Word2Vec(size,startingAlpha, window, minCount).fit(input)
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import scala.util.Random