@@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
4646 val labels : Array [Double ],
4747 val pi : Array [Double ],
4848 val theta : Array [Array [Double ]],
49- val modelType : String )
49+ val modelType : NaiveBayes . ModelType )
5050 extends ClassificationModel with Serializable with Saveable {
5151
5252 private [mllib] def this (labels : Array [Double ], pi : Array [Double ], theta : Array [Array [Double ]]) =
53- this (labels, pi, theta, NaiveBayes .Multinomial .toString )
53+ this (labels, pi, theta, NaiveBayes .Multinomial )
5454
5555 private val brzPi = new BDV [Double ](pi)
5656 private val brzTheta = new BDM (theta(0 ).length, theta.length, theta.flatten).t
5757
5858 // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
5959 // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
6060 // application of this condition (in predict function).
61- private val (brzNegTheta, brzNegThetaSum) = NaiveBayes . ModelType .fromString( modelType) match {
61+ private val (brzNegTheta, brzNegThetaSum) = modelType match {
6262 case NaiveBayes .Multinomial => (None , None )
6363 case NaiveBayes .Bernoulli =>
6464 val negTheta = brzLog((brzExp(brzTheta.copy) :*= (- 1.0 )) :+= 1.0 ) // log(1.0 - exp(x))
@@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
7474 }
7575
7676 override def predict (testData : Vector ): Double = {
77- NaiveBayes . ModelType .fromString( modelType) match {
77+ modelType match {
7878 case NaiveBayes .Multinomial =>
7979 labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
8080 case NaiveBayes .Bernoulli =>
@@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
8484 }
8585
8686 override def save (sc : SparkContext , path : String ): Unit = {
87- val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType)
87+ val data = NaiveBayesModel .SaveLoadV1_0 .Data (labels, pi, theta, modelType.toString )
8888 NaiveBayesModel .SaveLoadV1_0 .save(sc, path, data)
8989 }
9090
@@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
137137 val labels = data.getAs[Seq [Double ]](0 ).toArray
138138 val pi = data.getAs[Seq [Double ]](1 ).toArray
139139 val theta = data.getAs[Seq [Seq [Double ]]](2 ).map(_.toArray).toArray
140- val modelType = NaiveBayes .ModelType .fromString(data.getString(3 )).toString
140+ val modelType = NaiveBayes .ModelType .fromString(data.getString(3 ))
141141 new NaiveBayesModel (labels, pi, theta, modelType)
142142 }
143143 }
144144
145145 override def load (sc : SparkContext , path : String ): NaiveBayesModel = {
146- def getModelType (metadata : JValue ): String = {
146+ def getModelType (metadata : JValue ): NaiveBayes . ModelType = {
147147 implicit val formats = DefaultFormats
148- NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ]).toString
148+ NaiveBayes .ModelType .fromString((metadata \ " modelType" ).extract[String ])
149149 }
150150 val (loadedClassName, version, metadata) = loadMetadata(sc, path)
151151 val classNameV1_0 = SaveLoadV1_0 .thisClassName
@@ -202,7 +202,7 @@ class NaiveBayes private (
202202 this
203203 }
204204
205- def getModelType () : NaiveBayes .ModelType = this .modelType
205+ def getModelType : NaiveBayes .ModelType = this .modelType
206206
207207 /**
208208 * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
@@ -266,7 +266,7 @@ class NaiveBayes private (
266266 i += 1
267267 }
268268
269- new NaiveBayesModel (labels, pi, theta, modelType.toString )
269+ new NaiveBayesModel (labels, pi, theta, modelType)
270270 }
271271}
272272
@@ -328,9 +328,9 @@ object NaiveBayes {
328328 }
329329
330330 /** Provides static methods for using ModelType. */
331- sealed abstract class ModelType
331+ sealed abstract class ModelType extends Serializable
332332
333- object MODELTYPE {
333+ object MODELTYPE extends Serializable {
334334 final val MULTINOMIAL_STRING = " multinomial"
335335 final val BERNOULLI_STRING = " bernoulli"
336336
0 commit comments