Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
Expand Down Expand Up @@ -339,14 +340,22 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
val wordVectors = instance.wordVectors.getVectors
val dataSeq = wordVectors.toSeq.map { case (word, vector) => Data(word, vector) }
val dataPath = new Path(path, "data").toString
val numPartitions = Word2VecModelWriter.calculateNumberOfPartitions(
sc, instance.wordVectors.wordIndex.size, instance.getVectorSize)
sparkSession.createDataFrame(dataSeq)
.repartition(calculateNumberOfPartitions)
.repartition(numPartitions)
.write
.parquet(dataPath)
}
}

def calculateNumberOfPartitions(): Int = {
val floatSize = 4
private[feature]
object Word2VecModelWriter {
def calculateNumberOfPartitions(
sc: SparkContext,
numWords: Int,
vectorSize: Int): Int = {
val floatSize = 4L // Use Long to help avoid overflow
val averageWordSize = 15
// [SPARK-11994] - We want to partition the model in partitions smaller than
// spark.kryoserializer.buffer.max
Expand All @@ -355,9 +364,12 @@ object Word2VecModel extends MLReadable[Word2VecModel] {
// Calculate the approximate size of the model.
// Assuming an average word size of 15 bytes, the formula is:
// (floatSize * vectorSize + 15) * numWords
val numWords = instance.wordVectors.wordIndex.size
val approximateSizeInBytes = (floatSize * instance.getVectorSize + averageWordSize) * numWords
((approximateSizeInBytes / bufferSizeInBytes) + 1).toInt
val approximateSizeInBytes = (floatSize * vectorSize + averageWordSize) * numWords
val numPartitions = (approximateSizeInBytes / bufferSizeInBytes) + 1
require(numPartitions < 10e8, s"Word2VecModel calculated that it needs $numPartitions " +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this failure truly necessary? Could we make this a WARN and fall back to a best-effort number of partitions (Int.MaxValue?) instead? The models that would fail here would be so huge that they likely took days to train.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure it is necessary. If we cap it at Int.MaxValue and the user hits that cap, then we'll still fail when trying to write the partitions.

s"partitions to save this model, which is too large. Try increasing " +
s"spark.kryoserializer.buffer.max so that Word2VecModel can use fewer partitions.")
numPartitions.toInt
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(math.abs(similarity(5) - similarityLarger(5) / similarity(5)) > 1E-5)
}

test("Word2Vec read/write numPartitions calculation") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this test hardcode a specific spark.kryoserializer.buffer.max? That could allow us to be more explicit in the assertions.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point; I'll do that.

val tinyModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
sc, numWords = 10, vectorSize = 5)
assert(tinyModelNumPartitions === 1)
val mediumModelNumPartitions = Word2VecModel.Word2VecModelWriter.calculateNumberOfPartitions(
sc, numWords = 1000000, vectorSize = 5000)
assert(mediumModelNumPartitions > 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about a test for a truly large model that would have otherwise caused an overflow?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

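The "medium" one did cause an overflow: with Int arithmetic, (4 * 5000 + 15) * 1000000 = 20,015,000,000, which exceeds Int.MaxValue (2,147,483,647).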

}

test("Word2Vec read/write") {
val t = new Word2Vec()
.setInputCol("myInputCol")
Expand Down