Skip to content

Commit

Permalink
MixtureMachine API
Browse files Browse the repository at this point in the history
  • Loading branch information
mandar2812 committed Jun 22, 2017
1 parent 45adb7b commit fd818b6
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ import breeze.stats.distributions.{ContinuousDistr, Moments}
import io.github.mandar2812.dynaml.algebra.{PartitionedPSDMatrix, PartitionedVector}
import io.github.mandar2812.dynaml.models.gp.AbstractGPRegressionModel
import io.github.mandar2812.dynaml.models.stp.{AbstractSTPRegressionModel, MVStudentsTModel}
import io.github.mandar2812.dynaml.models.{ContinuousProcessModel, GenContinuousMixtureModel, SecondOrderProcessModel, StochasticProcessMixtureModel}
import io.github.mandar2812.dynaml.models.{
ContinuousProcessModel, GenContinuousMixtureModel,
SecondOrderProcessModel, StochasticProcessMixtureModel}
import io.github.mandar2812.dynaml.optimization.GloballyOptimizable
import io.github.mandar2812.dynaml.pipes.DataPipe2
import io.github.mandar2812.dynaml.probability.{ContinuousRVWithDistr, MatrixTRV, MultGaussianPRV, MultStudentsTPRV}
import io.github.mandar2812.dynaml.probability.distributions.{BlockedMultiVariateGaussian, BlockedMultivariateStudentsT, HasErrorBars, MatrixT}
import io.github.mandar2812.dynaml.probability.distributions.{
BlockedMultiVariateGaussian, BlockedMultivariateStudentsT,
HasErrorBars, MatrixT}

import scala.reflect.ClassTag

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ import scala.reflect.ClassTag

/**
* ::Experimental::
* @author mandar date 02/01/2017.
*
* A warped Gaussian Process.
*/
* @author mandar date 02/01/2017.
*
* */
@Experimental
class WarpedGPModel[T, I:ClassTag](p: AbstractGPRegressionModel[T, I])(
warpingFunc: PushforwardMap[Double, Double, Double])(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ import scala.reflect.ClassTag
* @tparam I The index set/input domain of the GP model.
* @author mandar2812 date 15/06/2017.
* */
class ProbGPMixtureMachine[T, I: ClassTag](
class GPMixtureMachine[T, I: ClassTag](
model: AbstractGPRegressionModel[T, I]) extends
MixtureMachine[T, I, Double, PartitionedVector, PartitionedPSDMatrix, BlockedMultiVariateGaussian,
MultGaussianPRV, AbstractGPRegressionModel[T, I]](model) {
MixtureMachine[
T, I, Double, PartitionedVector, PartitionedPSDMatrix,
BlockedMultiVariateGaussian, MultGaussianPRV,
AbstractGPRegressionModel[T, I]](model) {

val (kernelPipe, noisePipe) = (system.covariance.asPipe, system.noiseModel.asPipe)

Expand All @@ -33,7 +35,8 @@ class ProbGPMixtureMachine[T, I: ClassTag](
(model_state: Map[String, Double]) =>
AbstractGPRegressionModel(
kernelPipe(model_state), noisePipe(model_state),
system.mean)(system.data, system.npoints))
system.mean)(system.data, system.npoints)
)

override val mixturePipe = new GPMixturePipe[T, I]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,24 @@ BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
)
}

}

object MixtureMachine {

def apply[
T, I: ClassTag, Y, YDomain, YDomainVar,
BaseDistr <: ContinuousDistr[YDomain] with Moments[YDomain, YDomainVar] with HasErrorBars[YDomain],
W1 <: ContinuousRVWithDistr[YDomain, BaseDistr],
BaseProcess <: ContinuousProcessModel[T, I, Y, W1]
with SecondOrderProcessModel[T, I, Y, Double, DenseMatrix[Double], W1]
with GloballyOptimizable](model: BaseProcess)(
confModelPipe: DataPipe[Map[String, Double], BaseProcess],
mixtPipe: DataPipe2[Seq[BaseProcess], DenseVector[Double], GenContinuousMixtureModel[
T, I, Y, YDomain, YDomainVar,
BaseDistr, W1, BaseProcess]]) =
new MixtureMachine[T, I, Y, YDomain, YDomainVar, BaseDistr, W1, BaseProcess](model) {
override val confToModel = confModelPipe
override val mixturePipe = mixtPipe
}

}
4 changes: 2 additions & 2 deletions scripts/stochasticPriors.sc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.github.mandar2812.dynaml.models.bayes.{LinearTrendESGPrior, LinearTren
import io.github.mandar2812.dynaml.probability._
import com.quantifind.charts.Highcharts._
import io.github.mandar2812.dynaml.analysis.implicits._
import io.github.mandar2812.dynaml.optimization.ProbGPMixtureMachine
import io.github.mandar2812.dynaml.optimization.GPMixtureMachine
import io.github.mandar2812.dynaml.pipes.Encoder
import io.github.mandar2812.dynaml.probability.distributions.UnivariateGaussian

Expand Down Expand Up @@ -63,7 +63,7 @@ val sgpModel = sgp_prior.posteriorModel(dataset)
gp_prior.globalOptConfig_(Map("gridStep" -> "0.0", "gridSize" -> "1", "globalOpt" -> "GS", "policy" -> "GS"))
val gpModel1 = gp_prior.posteriorModel(dataset)

val mixt_machine = new ProbGPMixtureMachine(gpModel1).setPrior(hyp_prior).setGridSize(2).setStepSize(0.50).setLogScale(true).setMaxIterations(200).setNumSamples(3)
val mixt_machine = new GPMixtureMachine(gpModel1).setPrior(hyp_prior).setGridSize(2).setStepSize(0.50).setLogScale(true).setMaxIterations(200).setNumSamples(3)

val (mix_model, mixt_model_conf) = mixt_machine.optimize(gp_prior.covariance.effective_state ++ gp_prior.noiseCovariance.effective_state)

Expand Down

0 comments on commit fd818b6

Please sign in to comment.