Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,19 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)

/** @group setParam */
@Since("1.2.0")
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
def setEstimator(value: Estimator[_]): this.type = setEstimators(Array(value))

/** @group setParam */
@Since("1.2.0")
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
def setEstimatorParamMaps(value: Array[ParamMap]): this.type =
setEstimatorsParamMaps(Array(value))

/** @group setParam */
def setEstimators(value: Array[Estimator[_]]): this.type = set(estimators, value)

/** @group setParam */
def setEstimatorsParamMaps(value: Array[Array[ParamMap]]): this.type =
set(estimatorsParamMaps, value)

/** @group setParam */
@Since("1.2.0")
Expand All @@ -96,42 +104,46 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val schema = dataset.schema
transformSchema(schema, logging = true)
val sparkSession = dataset.sparkSession
val est = $(estimator)
val ests = $(estimators)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
val epms = $(estimatorsParamMaps).flatten
val metrics = new Array[Double](getModelCount)
val modelToEstIndex = getModelToEstIndex

val instr = Instrumentation.create(this, dataset)
instr.logParams(numFolds, seed)
logTuningParams(instr)
ests.indices.foreach(logTuningParams(instr, _))

val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
val models = ests.zip($(estimatorsParamMaps))
.flatMap(estEpm => estEpm._1.fit(trainingDataset, estEpm._2).asInstanceOf[Seq[Model[_]]])
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
while (i < getModelCount) {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
val metric = eval.evaluate(models(i).transform(validationDataset, epms(i)))
logDebug(s"Got metric $metric for model trained with " +
s"${ests(modelToEstIndex(i))} and parameters ${epms(i)}.")
metrics(i) += metric
i += 1
}
validationDataset.unpersist()
}
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
f2jBLAS.dscal(getModelCount, 1.0 / $(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best estimator:\n${ests(modelToEstIndex(bestIndex))}")
logInfo(s"Best set of parameters:\n${epms(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
val bestModel = ests(modelToEstIndex(bestIndex))
.fit(dataset, epms(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
Expand All @@ -142,8 +154,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
override def copy(extra: ParamMap): CrossValidator = {
val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
if (copied.isDefined(estimator)) {
copied.setEstimator(copied.getEstimator.copy(extra))
if (copied.isDefined(estimators)) {
copied.setEstimators(copied.getEstimators.map(_.copy(extra)))
}
if (copied.isDefined(evaluator)) {
copied.setEvaluator(copied.getEvaluator.copy(extra))
Expand Down Expand Up @@ -183,14 +195,14 @@ object CrossValidator extends MLReadable[CrossValidator] {
override def load(path: String): CrossValidator = {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
val (metadata, estimators, evaluator, estimatorsParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEstimators(estimators)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setEstimatorsParamMaps(estimatorsParamMaps)
.setNumFolds(numFolds)
.setSeed(seed)
}
Expand Down Expand Up @@ -273,17 +285,17 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
override def load(path: String): CrossValidatorModel = {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
val (metadata, estimators, evaluator, estimatorsParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
model.set(model.estimator, estimator)
model.set(model.estimators, estimators)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.estimatorsParamMaps, estimatorsParamMaps)
.set(model.numFolds, numFolds)
.set(model.seed, seed)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,24 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St

/** @group setParam */
@Since("1.5.0")
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
def setEstimator(value: Estimator[_]): this.type = setEstimators(Array(value))

/** @group setParam */
@Since("1.5.0")
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
def setEstimatorParamMaps(value: Array[ParamMap]): this.type =
setEstimatorsParamMaps(Array(value))

/** @group setParam */
@Since("1.5.0")
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)

/** @group setParam */
def setEstimators(value: Array[Estimator[_]]): this.type = set(estimators, value)

/** @group setParam */
def setEstimatorsParamMaps(value: Array[Array[ParamMap]]): this.type =
set(estimatorsParamMaps, value)

/** @group setParam */
@Since("1.5.0")
def setTrainRatio(value: Double): this.type = set(trainRatio, value)
Expand All @@ -91,15 +99,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
val est = $(estimator)
val ests = $(estimators)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
val epms = $(estimatorsParamMaps).flatten
val metrics = new Array[Double](getModelCount)
val modelToEstIndex = getModelToEstIndex

val instr = Instrumentation.create(this, dataset)
instr.logParams(trainRatio, seed)
logTuningParams(instr)
ests.indices.foreach(logTuningParams(instr, _))

val Array(trainingDataset, validationDataset) =
dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
Expand All @@ -108,13 +116,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St

// multi-model training
logDebug(s"Train split with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
val models = ests.zip($(estimatorsParamMaps))
.flatMap(estEpm => estEpm._1.fit(trainingDataset, estEpm._2).asInstanceOf[Seq[Model[_]]])
trainingDataset.unpersist()
var i = 0
while (i < numModels) {
while (i < getModelCount) {
// TODO: duplicate evaluator to take extra params from input
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
val metric = eval.evaluate(models(i).transform(validationDataset, epms(i)))
logDebug(s"Got metric $metric for model trained with " +
s"${ests(modelToEstIndex(i))} and parameters ${epms(i)}.")
metrics(i) += metric
i += 1
}
Expand All @@ -124,9 +134,11 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
else metrics.zipWithIndex.minBy(_._1)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best estimator:\n${ests(modelToEstIndex(bestIndex))}")
logInfo(s"Best set of parameters:\n${epms(bestIndex)}")
logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
val bestModel = ests(modelToEstIndex(bestIndex))
.fit(dataset, epms(bestIndex)).asInstanceOf[Model[_]]
instr.logSuccess(bestModel)
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
}
Expand All @@ -137,8 +149,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("1.5.0")
override def copy(extra: ParamMap): TrainValidationSplit = {
val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
if (copied.isDefined(estimator)) {
copied.setEstimator(copied.getEstimator.copy(extra))
if (copied.isDefined(estimators)) {
copied.setEstimators(copied.getEstimators.map(_.copy(extra)))
}
if (copied.isDefined(evaluator)) {
copied.setEvaluator(copied.getEvaluator.copy(extra))
Expand Down Expand Up @@ -176,14 +188,14 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
override def load(path: String): TrainValidationSplit = {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
val (metadata, estimators, evaluator, estimatorsParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
val seed = (metadata.params \ "seed").extract[Long]
new TrainValidationSplit(metadata.uid)
.setEstimator(estimator)
.setEstimators(estimators)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setEstimatorsParamMaps(estimatorsParamMaps)
.setTrainRatio(trainRatio)
.setSeed(seed)
}
Expand Down Expand Up @@ -264,17 +276,17 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
override def load(path: String): TrainValidationSplitModel = {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
val (metadata, estimators, evaluator, estimatorsParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
model.set(model.estimator, estimator)
model.set(model.estimators, estimators)
.set(model.evaluator, evaluator)
.set(model.estimatorParamMaps, estimatorParamMaps)
.set(model.estimatorsParamMaps, estimatorsParamMaps)
.set(model.trainRatio, trainRatio)
.set(model.seed, seed)
}
Expand Down
Loading