Skip to content

Commit a22d670

Browse files
committed
changed NaiveBayesModel modelType parameter back to NaiveBayes.ModelType, made NaiveBayes.ModelType serializable, fixed getter method in NaiveBayes
1 parent 18f3219 commit a22d670

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
4646
val labels: Array[Double],
4747
val pi: Array[Double],
4848
val theta: Array[Array[Double]],
49-
val modelType: String)
49+
val modelType: NaiveBayes.ModelType)
5050
extends ClassificationModel with Serializable with Saveable {
5151

5252
private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
53-
this(labels, pi, theta, NaiveBayes.Multinomial.toString)
53+
this(labels, pi, theta, NaiveBayes.Multinomial)
5454

5555
private val brzPi = new BDV[Double](pi)
5656
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t
5757

5858
// Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
5959
// This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6060
// application of this condition (in predict function).
61-
private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match {
61+
private val (brzNegTheta, brzNegThetaSum) = modelType match {
6262
case NaiveBayes.Multinomial => (None, None)
6363
case NaiveBayes.Bernoulli =>
6464
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
@@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
7474
}
7575

7676
override def predict(testData: Vector): Double = {
77-
NaiveBayes.ModelType.fromString(modelType) match {
77+
modelType match {
7878
case NaiveBayes.Multinomial =>
7979
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
8080
case NaiveBayes.Bernoulli =>
@@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
8484
}
8585

8686
override def save(sc: SparkContext, path: String): Unit = {
87-
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
87+
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
8888
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
8989
}
9090

@@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
137137
val labels = data.getAs[Seq[Double]](0).toArray
138138
val pi = data.getAs[Seq[Double]](1).toArray
139139
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
140-
val modelType = NaiveBayes.ModelType.fromString(data.getString(3)).toString
140+
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
141141
new NaiveBayesModel(labels, pi, theta, modelType)
142142
}
143143
}
144144

145145
override def load(sc: SparkContext, path: String): NaiveBayesModel = {
146-
def getModelType(metadata: JValue): String = {
146+
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
147147
implicit val formats = DefaultFormats
148-
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]).toString
148+
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
149149
}
150150
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
151151
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -202,7 +202,7 @@ class NaiveBayes private (
202202
this
203203
}
204204

205-
def getModelType(): NaiveBayes.ModelType = this.modelType
205+
def getModelType: NaiveBayes.ModelType = this.modelType
206206

207207
/**
208208
* Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -266,7 +266,7 @@ class NaiveBayes private (
266266
i += 1
267267
}
268268

269-
new NaiveBayesModel(labels, pi, theta, modelType.toString)
269+
new NaiveBayesModel(labels, pi, theta, modelType)
270270
}
271271
}
272272

@@ -328,9 +328,9 @@ object NaiveBayes {
328328
}
329329

330330
/** Provides static methods for using ModelType. */
331-
sealed abstract class ModelType
331+
sealed abstract class ModelType extends Serializable
332332

333-
object MODELTYPE {
333+
object MODELTYPE extends Serializable{
334334
final val MULTINOMIAL_STRING = "multinomial"
335335
final val BERNOULLI_STRING = "bernoulli"
336336

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ object NaiveBayesSuite {
7676

7777
/** Binary labels, 3 features */
7878
private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
79-
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli.toString)
79+
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli)
8080
}
8181

8282
class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {

0 commit comments

Comments
 (0)