From aa2ab36c6a9fa22e759ccb99352394dd6d6317e0 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Wed, 13 Aug 2014 05:45:17 -0700
Subject: [PATCH 1/5] use reduceByKey to combine models

---
 .../apache/spark/mllib/feature/Word2Vec.scala | 65 ++++++++++++++-----
 1 file changed, 50 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index ecd49ea2ff53..3ec1b3407a87 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
-
+import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
 /**
  * Entry in vocabulary
  */
@@ -292,6 +292,9 @@ class Word2Vec extends Serializable with Logging {
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val syn0Modify = new Array[Int](vocabSize)
+        val syn1Modify = new Array[Int](vocabSize)
+
         val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
@@ -321,8 +324,10 @@ class Word2Vec extends Serializable with Logging {
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
+                      val ind = bcVocab.value(word).point(d)
                       val l2 = bcVocab.value(word).point(d) * vectorSize
                       // Propagate hidden -> output
+                      syn1Modify(ind) += 1
                       var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
                         val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
@@ -334,6 +339,7 @@ class Word2Vec extends Serializable with Logging {
                       d += 1
                     }
                     blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+                    syn0Modify(lastWord) += 1
                   }
                 }
                 a += 1
@@ -342,21 +348,50 @@ class Word2Vec extends Serializable with Logging {
             }
             (syn0, syn1, lwc, wc)
         }
-        Iterator(model)
-      }
-      val (aggSyn0, aggSyn1, _, _) =
-        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-          blas.sscal(n, weight1, syn0_1, 1)
-          blas.sscal(n, weight1, syn1_1, 1)
-          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        val syn0Local = model._1
+        val syn1Local = model._2
+
+        val syn0Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]]
+        // val syn1Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]]
+        var index = 0
+        while(index < vocabSize) {
+          if (syn0Modify(index) != 0) syn0Out.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          if (syn1Modify(index) != 0) syn0Out.update(index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          index += 1
         }
-      syn0Global = aggSyn0
-      syn1Global = aggSyn1
+        Iterator(syn0Out)
+      }
+      // partial.cache()
+
+      val synAgg = partial.flatMap(x => x).reduceByKey {
+        case (v1,v2) =>
+          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+          v1
+      }.collect().sortBy(_._1).flatMap(x => x._2)
+
+      syn0Global = synAgg.slice(0, vocabSize * vectorSize)
+      syn1Global = synAgg.slice(vocabSize * vectorSize, synAgg.length)
+      //logInfo("syn0Global length = " + syn0Global.length)
+      // syn1Global = partial.flatMap(x => x._2).reduceByKey {
+      //  case (v1,v2) =>
+      //    blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+      //    v1
+      // }.collect().sortBy(_._1).flatMap(x => x._2)
+      // logInfo("syn1Global length = " + syn1Global.length)
+      // logInfo("vocab size = " + vocabSize)
+      // val (aggSyn0, aggSyn1, _, _) =
+      //  partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+      //    val n = syn0_1.length
+      //    val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+      //    val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+      //    blas.sscal(n, weight1, syn0_1, 1)
+      //    blas.sscal(n, weight1, syn1_1, 1)
+      //    blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+      //    blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+      //    (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+      //  }
+      //syn0Global = aggSyn0
+      //syn1Global = aggSyn1
     }
     newSentences.unpersist()
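
Patch 1/5 above is the core change of the series: instead of treeReduce-averaging two dense vocabSize * vectorSize matrices from every partition, each partition now records which syn0/syn1 rows it modified (syn0Modify/syn1Modify), emits only those rows keyed by row index (syn1 keys offset by vocabSize), and reduceByKey sums matching rows across partitions. Below is a minimal, self-contained sketch of that combine pattern, not code from the patch; the object name, toy rows, and local[2] master are illustrative assumptions, and a plain while-loop sum stands in for the blas.saxpy call.

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.SparkContext._

  object SparseCombineSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setAppName("sparse-combine").setMaster("local[2]"))
      val vectorSize = 4
      // Stand-in for `partial`: each element is one touched row, as (rowKey, rowValues).
      val partial = sc.parallelize(Seq(
        (0, Array(1f, 0f, 0f, 0f)),    // "partition A" touched row 0
        (2, Array(0f, 2f, 0f, 0f)),    // "partition A" touched row 2
        (0, Array(0f, 0f, 3f, 0f))))   // "partition B" touched row 0
      // Rows with the same key are summed element-wise; untouched rows never move.
      val synAgg = partial.reduceByKey { (v1, v2) =>
        var j = 0
        while (j < vectorSize) { v1(j) += v2(j); j += 1 }
        v1
      }.collect().sortBy(_._1)
      synAgg.foreach { case (row, v) => println(s"row $row -> ${v.mkString(", ")}") }
      sc.stop()
    }
  }

Since only modified rows cross the shuffle, per-iteration traffic scales with the vocabulary a partition actually saw rather than with the full model size, which is what makes reduceByKey preferable to the old dense average here.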
From 9075e1cba5ae64add2986514be99dc51083ff177 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Wed, 13 Aug 2014 16:17:58 -0700
Subject: [PATCH 2/5] combine syn0Global and syn1Global to synGlobal

---
 .../apache/spark/mllib/feature/Word2Vec.scala | 78 ++++++------------
 1 file changed, 23 insertions(+), 55 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 3ec1b3407a87..8cfd13a837c9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -235,7 +235,7 @@ class Word2Vec extends Serializable with Logging {
       b = 0
       while (b < i) {
         vocab(a).code(i - b - 1) = code(b)
-        vocab(a).point(i - b) = point(b) - vocabSize
+        vocab(a).point(i - b) = point(b)
         b += 1
       }
       a += 1
@@ -284,19 +284,15 @@ class Word2Vec extends Serializable with Logging {
     val newSentences = sentences.repartition(numPartitions).cache()
     val initRandom = new XORShiftRandom(seed)
-    var syn0Global =
-      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
-    var syn1Global = new Array[Float](vocabSize * vectorSize)
-
+    var synGlobal =
+      Array.fill[Float](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
     var alpha = startingAlpha
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
-        val syn0Modify = new Array[Int](vocabSize)
-        val syn1Modify = new Array[Int](vocabSize)
-
-        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
-          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
+        val synModify = new Array[Int](2 * vocabSize)
+        val model = iter.foldLeft((synGlobal, 0, 0)) {
+          case ((syn, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
             var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
@@ -325,73 +321,45 @@ class Word2Vec extends Serializable with Logging {
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
                       val ind = bcVocab.value(word).point(d)
-                      val l2 = bcVocab.value(word).point(d) * vectorSize
+                      val l2 = ind * vectorSize
                       // Propagate hidden -> output
-                      syn1Modify(ind) += 1
-                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
+                      synModify(ind) += 1
+                      var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
                         val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                         f = expTable.value(ind)
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
-                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
-                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+                        blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
+                        blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
                       }
                       d += 1
                     }
-                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
-                    syn0Modify(lastWord) += 1
+                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn, l1, 1)
+                    synModify(lastWord) += 1
                   }
                 }
                 a += 1
              }
              pos += 1
            }
-            (syn0, syn1, lwc, wc)
+            (syn, lwc, wc)
         }
-        val syn0Local = model._1
-        val syn1Local = model._2
-
-        val syn0Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]]
-        // val syn1Out = new PrimitiveKeyOpenHashMap[Int, Array[Float]]
+        val synLocal = model._1
+        val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
         var index = 0
-        while(index < vocabSize) {
-          if (syn0Modify(index) != 0) syn0Out.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
-          if (syn1Modify(index) != 0) syn0Out.update(index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
+        while(index < 2 * vocabSize) {
+          if (synModify(index) != 0) {
+            synOut.update(index, synLocal.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
           index += 1
         }
-        Iterator(syn0Out)
+        Iterator(synOut)
       }
-      // partial.cache()
-
-      val synAgg = partial.flatMap(x => x).reduceByKey {
+      synGlobal = partial.flatMap(x => x).reduceByKey {
         case (v1,v2) =>
           blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
           v1
       }.collect().sortBy(_._1).flatMap(x => x._2)
-
-      syn0Global = synAgg.slice(0, vocabSize * vectorSize)
-      syn1Global = synAgg.slice(vocabSize * vectorSize, synAgg.length)
-      //logInfo("syn0Global length = " + syn0Global.length)
-      // syn1Global = partial.flatMap(x => x._2).reduceByKey {
-      //  case (v1,v2) =>
-      //    blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
-      //    v1
-      // }.collect().sortBy(_._1).flatMap(x => x._2)
-      // logInfo("syn1Global length = " + syn1Global.length)
-      // logInfo("vocab size = " + vocabSize)
-      // val (aggSyn0, aggSyn1, _, _) =
-      //  partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-      //    val n = syn0_1.length
-      //    val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-      //    val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-      //    blas.sscal(n, weight1, syn0_1, 1)
-      //    blas.sscal(n, weight1, syn1_1, 1)
-      //    blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-      //    blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-      //    (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-      //  }
-      //syn0Global = aggSyn0
-      //syn1Global = aggSyn1
     }
     newSentences.unpersist()
@@ -400,7 +368,7 @@ class Word2Vec extends Serializable with Logging {
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](vectorSize)
-      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
+      Array.copy(synGlobal, i * vectorSize, vector, 0, vectorSize)
       word2VecMap += word -> vector
       i += 1
     }
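
Patch 2/5 folds syn0 and syn1 into a single flat synGlobal of length 2 * vocabSize * vectorSize, so one modify array and one hash map cover both matrices: keys in [0, vocabSize) address syn0 rows and keys in [vocabSize, 2 * vocabSize) address syn1 rows, each owning the flat slot starting at key * vectorSize. Dropping `- vocabSize` from the Huffman point values fits this layout, since raw inner-node ids already land in the syn1 half of the combined array. A small hypothetical helper, only to illustrate the indexing scheme (not part of the patch):

  object SynIndexing {
    // layout: syn0 row i at key i; syn1 row i at key i + vocabSize
    def key(isSyn1: Boolean, row: Int, vocabSize: Int): Int =
      if (isSyn1) row + vocabSize else row

    // every key owns one contiguous vectorSize-wide slot of the flat array
    def slot(synGlobal: Array[Float], key: Int, vectorSize: Int): Array[Float] =
      synGlobal.slice(key * vectorSize, (key + 1) * vectorSize)

    def main(args: Array[String]): Unit = {
      val (vocabSize, vectorSize) = (3, 2)
      val synGlobal = Array.tabulate(2 * vocabSize * vectorSize)(_.toFloat)
      // syn1 row 1 -> key 4 -> flat positions 8 and 9
      val r = slot(synGlobal, key(isSyn1 = true, row = 1, vocabSize = vocabSize), vectorSize)
      println(r.mkString(", "))  // 8.0, 9.0
    }
  }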
From 083aa66a5822e5169e4b7fe067e8c0735ffc27cb Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Thu, 14 Aug 2014 10:26:02 -0700
Subject: [PATCH 3/5] update synGlobal in place and reduce synOut size

---
 .../org/apache/spark/mllib/feature/Word2Vec.scala | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 8cfd13a837c9..de0c692d2a5b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -35,6 +35,7 @@ import org.apache.spark.rdd._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandom
 import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap
+
 /**
  * Entry in vocabulary
  */
@@ -323,7 +324,6 @@ class Word2Vec extends Serializable with Logging {
                       val ind = bcVocab.value(word).point(d)
                       val l2 = ind * vectorSize
                       // Propagate hidden -> output
-                      synModify(ind) += 1
                       var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
                         val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
@@ -331,6 +331,7 @@ class Word2Vec extends Serializable with Logging {
                         f = expTable.value(ind)
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                         blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
                         blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
+                        synModify(ind) += 1
                       }
                       d += 1
                     }
@@ -355,11 +356,15 @@ class Word2Vec extends Serializable with Logging {
         }
         Iterator(synOut)
       }
-      synGlobal = partial.flatMap(x => x).reduceByKey {
-        case (v1,v2) =>
+      val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
          v1
-      }.collect().sortBy(_._1).flatMap(x => x._2)
+      }.collect()
+      var i = 0
+      while (i < synAgg.length) {
+        Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
+        i += 1
+      }
     }
     newSentences.unpersist()
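
Patch 3/5 makes two refinements. First, synModify(ind) is now incremented only inside the if block, i.e. only when the row really received a saxpy update, which keeps synOut smaller. Second, the driver no longer sorts all aggregated keys and rebuilds synGlobal from scratch; it scatters each summed row into place, so rows no partition touched keep their previous values. A standalone sketch of the scatter-back step, with made-up sizes and data:

  object ScatterBackSketch {
    def main(args: Array[String]): Unit = {
      val vectorSize = 2
      val synGlobal = Array.fill(4 * vectorSize)(9f)             // 4 rows of old values
      val synAgg = Array((1, Array(5f, 6f)), (3, Array(7f, 8f))) // collected (key, summedRow) pairs
      var i = 0
      while (i < synAgg.length) {
        // copy each aggregated row into its slot; rows absent from synAgg are untouched
        Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
        i += 1
      }
      println(synGlobal.mkString(", "))  // 9.0, 9.0, 5.0, 6.0, 9.0, 9.0, 7.0, 8.0
    }
  }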
From cad201140970df237fc5492691c774e4d2d83763 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Thu, 14 Aug 2014 11:54:59 -0700
Subject: [PATCH 4/5] bug fix for synModify array out of bound

---
 .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index de0c692d2a5b..4d965e42c728 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -321,8 +321,8 @@ class Word2Vec extends Serializable with Logging {
                     // Hierarchical softmax
                     var d = 0
                     while (d < bcVocab.value(word).codeLen) {
-                      val ind = bcVocab.value(word).point(d)
-                      val l2 = ind * vectorSize
+                      val inner = bcVocab.value(word).point(d)
+                      val l2 = inner * vectorSize
                       // Propagate hidden -> output
                       var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
@@ -331,7 +331,7 @@ class Word2Vec extends Serializable with Logging {
                         val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                         f = expTable.value(ind)
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
                         blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
                         blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
-                        synModify(ind) += 1
+                        synModify(inner) += 1
                       }
                       d += 1
                     }
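
The one-line rename in patch 4/5 fixes an out-of-bounds write that patch 3 introduced by moving the counter update inside the if block: there the name `ind` is shadowed by the inner `val ind`, the expTable lookup index (which can be far larger than 2 * vocabSize), so synModify(ind) indexed with the wrong value. Renaming the tree-node index to `inner` removes the shadowing. The hazard in miniature, with made-up values:

  object ShadowingSketch {
    def main(args: Array[String]): Unit = {
      val synModify = new Array[Int](10)  // stand-in for the 2 * vocabSize counter array
      val ind = 7                         // outer: a tree-node index, always in bounds
      if (ind >= 0) {
        val ind = 930                     // inner: an expTable-style index shadows the name
        try synModify(ind) += 1           // silently resolves to 930, not 7
        catch { case e: ArrayIndexOutOfBoundsException => println("out of bounds: " + e.getMessage) }
      }
    }
  }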
From d5377a9ea607d015fce4a2ac7eebdb467db5f46f Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sat, 16 Aug 2014 23:08:28 -0700
Subject: [PATCH 5/5] use syn0Global and syn1Global to represent model

---
 .../apache/spark/mllib/feature/Word2Vec.scala | 52 ++++++++++++-------
 1 file changed, 32 insertions(+), 20 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 4d965e42c728..d2ae62b482af 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -236,7 +236,7 @@ class Word2Vec extends Serializable with Logging {
       b = 0
       while (b < i) {
         vocab(a).code(i - b - 1) = code(b)
-        vocab(a).point(i - b) = point(b)
+        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
       }
       a += 1
@@ -285,15 +285,17 @@ class Word2Vec extends Serializable with Logging {
     val newSentences = sentences.repartition(numPartitions).cache()
     val initRandom = new XORShiftRandom(seed)
-    var synGlobal =
-      Array.fill[Float](2 * vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
+    var syn0Global =
+      Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
+    var syn1Global = new Array[Float](vocabSize * vectorSize)
     var alpha = startingAlpha
     for (k <- 1 to numIterations) {
       val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
         val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
-        val synModify = new Array[Int](2 * vocabSize)
-        val model = iter.foldLeft((synGlobal, 0, 0)) {
-          case ((syn, lastWordCount, wordCount), sentence) =>
+        val syn0Modify = new Array[Int](vocabSize)
+        val syn1Modify = new Array[Int](vocabSize)
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
             var wc = wordCount
             if (wordCount - lastWordCount > 10000) {
@@ -324,45 +326,55 @@ class Word2Vec extends Serializable with Logging {
                       val inner = bcVocab.value(word).point(d)
                       val l2 = inner * vectorSize
                       // Propagate hidden -> output
-                      var f = blas.sdot(vectorSize, syn, l1, 1, syn, l2, 1)
+                      var f = blas.sdot(vectorSize, syn0, l1, 1, syn1, l2, 1)
                       if (f > -MAX_EXP && f < MAX_EXP) {
                         val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                         f = expTable.value(ind)
                         val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
-                        blas.saxpy(vectorSize, g, syn, l2, 1, neu1e, 0, 1)
-                        blas.saxpy(vectorSize, g, syn, l1, 1, syn, l2, 1)
-                        synModify(inner) += 1
+                        blas.saxpy(vectorSize, g, syn1, l2, 1, neu1e, 0, 1)
+                        blas.saxpy(vectorSize, g, syn0, l1, 1, syn1, l2, 1)
+                        syn1Modify(inner) += 1
                       }
                       d += 1
                     }
-                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn, l1, 1)
-                    synModify(lastWord) += 1
+                    blas.saxpy(vectorSize, 1.0f, neu1e, 0, 1, syn0, l1, 1)
+                    syn0Modify(lastWord) += 1
                   }
                 }
                 a += 1
              }
              pos += 1
            }
-            (syn, lwc, wc)
+            (syn0, syn1, lwc, wc)
         }
-        val synLocal = model._1
+        val syn0Local = model._1
+        val syn1Local = model._2
         val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
         var index = 0
-        while(index < 2 * vocabSize) {
-          if (synModify(index) != 0) {
-            synOut.update(index, synLocal.slice(index * vectorSize, (index + 1) * vectorSize))
+        while(index < vocabSize) {
+          if (syn0Modify(index) != 0) {
+            synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
+          }
+          if (syn1Modify(index) != 0) {
+            synOut.update(index + vocabSize,
+              syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
          }
          index += 1
        }
        Iterator(synOut)
      }
      val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
-          blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
+        blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
          v1
      }.collect()
      var i = 0
      while (i < synAgg.length) {
-        Array.copy(synAgg(i)._2, 0, synGlobal, synAgg(i)._1 * vectorSize, vectorSize)
+        val index = synAgg(i)._1
+        if (index < vocabSize) {
+          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
+        } else {
+          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
+        }
        i += 1
      }
@@ -373,7 +385,7 @@ class Word2Vec extends Serializable with Logging {
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
       val vector = new Array[Float](vectorSize)
-      Array.copy(synGlobal, i * vectorSize, vector, 0, vectorSize)
+      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
       word2VecMap += word -> vector
       i += 1
     }
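
Patch 5/5 keeps the sparse reduceByKey combine but represents the model as two arrays again, restoring the `- vocabSize` offset on the Huffman point values accordingly; the shared key space survives only in synOut, where keys below vocabSize scatter into syn0Global and the rest into syn1Global. Below is an end-to-end sketch of the final per-iteration aggregation scheme; the object, toy data, and local[2] master are illustrative assumptions, and a while-loop sum stands in for blas.saxpy.

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.SparkContext._

  object Word2VecAggSketch {
    def main(args: Array[String]): Unit = {
      val sc = new SparkContext(new SparkConf().setAppName("agg-sketch").setMaster("local[2]"))
      val (vocabSize, vectorSize) = (3, 2)
      val syn0Global = new Array[Float](vocabSize * vectorSize)
      val syn1Global = new Array[Float](vocabSize * vectorSize)

      // Stand-in for `partial`: rows touched during training, keyed as in the patch
      // (keys < vocabSize are syn0 rows, keys >= vocabSize are syn1 rows).
      val partial = sc.parallelize(Seq(
        (0, Array(1f, 1f)), (vocabSize + 2, Array(2f, 2f)),  // "partition 1"
        (0, Array(3f, 3f)), (1, Array(4f, 4f))))             // "partition 2"

      val synAgg = partial.reduceByKey { (v1, v2) =>
        var j = 0
        while (j < vectorSize) { v1(j) += v2(j); j += 1 }    // element-wise row sum
        v1
      }.collect()

      // Scatter each summed row into whichever matrix its key range selects.
      var i = 0
      while (i < synAgg.length) {
        val index = synAgg(i)._1
        if (index < vocabSize) {
          Array.copy(synAgg(i)._2, 0, syn0Global, index * vectorSize, vectorSize)
        } else {
          Array.copy(synAgg(i)._2, 0, syn1Global, (index - vocabSize) * vectorSize, vectorSize)
        }
        i += 1
      }
      println("syn0Global: " + syn0Global.mkString(", "))    // 4.0, 4.0, 4.0, 4.0, 0.0, 0.0
      println("syn1Global: " + syn1Global.mkString(", "))    // 0.0, 0.0, 0.0, 0.0, 2.0, 2.0
      sc.stop()
    }
  }

Splitting the matrices back apart keeps the word vectors (syn0Global) contiguous for the final word2VecMap copy while still shipping only modified rows through the shuffle.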