diff --git a/src/Forest.cpp b/src/Forest.cpp index 6d70aa27..9a3e08b8 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -41,6 +41,10 @@ void Forest::init(std::unique_ptr input_data, PredType pred_type, bool pred_mode, bool pred_aggregate, + PartialDepType pd_type, + std::vector& pd_x_vals, + std::vector& pd_x_cols, + arma::vec& pd_probs, bool oobag_pred, EvalType oobag_eval_type, arma::uword oobag_eval_every, @@ -73,6 +77,10 @@ void Forest::init(std::unique_ptr input_data, this->pred_type = pred_type; this->pred_mode = pred_mode; this->pred_aggregate = pred_aggregate; + this->pd_type = pd_type; + this->pd_x_vals = pd_x_vals; + this->pd_x_cols = pd_x_cols; + this->pd_probs = pd_probs; this->oobag_pred = oobag_pred; this->oobag_eval_type = oobag_eval_type; this->oobag_eval_every = oobag_eval_every; diff --git a/src/Forest.h b/src/Forest.h index 39a5ee58..94734ea5 100644 --- a/src/Forest.h +++ b/src/Forest.h @@ -63,6 +63,10 @@ class Forest { PredType pred_type, bool pred_mode, bool pred_aggregate, + PartialDepType pd_type, + std::vector& pd_x_vals, + std::vector& pd_x_cols, + arma::vec& pd_probs, bool oobag_pred, EvalType oobag_eval_type, arma::uword oobag_eval_every, diff --git a/src/orsf_oop.cpp b/src/orsf_oop.cpp index 16010188..f22b401a 100644 --- a/src/orsf_oop.cpp +++ b/src/orsf_oop.cpp @@ -335,6 +335,10 @@ pred_type, pred_mode, pred_aggregate, + pd_type, + pd_x_vals, + pd_x_cols, + pd_probs, oobag, oobag_eval_type, oobag_eval_every,