Skip to content

Commit 5711788

Browse files
committed
add proper error handling
1 parent fc7f2e6 commit 5711788

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/tree/decision_tree_classifier.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
903903
///
904904
/// # Errors
905905
///
906-
/// Returns an error if the prediction process fails.
906+
/// Returns an error if at least one row prediction process fails.
907907
pub fn predict_proba(&self, x: &X) -> Result<DenseMatrix<f64>, Failed> {
908908
let (n_samples, _) = x.shape();
909909
let n_classes = self.classes().len();
910910
let mut result = DenseMatrix::<f64>::zeros(n_samples, n_classes);
911911

912912
for i in 0..n_samples {
913-
let probs = self.predict_proba_for_row(x, i);
913+
let probs = self.predict_proba_for_row(x, i)?;
914914
for (j, &prob) in probs.iter().enumerate() {
915915
result.set((i, j), prob);
916916
}
@@ -930,15 +930,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
930930
///
931931
/// A vector of probabilities, one for each class, representing the probability
932932
/// of the input sample belonging to each class.
933-
fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec<f64> {
933+
fn predict_proba_for_row(&self, x: &X, row: usize) -> Result<Vec<f64>, Failed> {
934934
let mut node = 0;
935935

936936
while let Some(current_node) = self.nodes().get(node) {
937937
if current_node.true_child.is_none() && current_node.false_child.is_none() {
938938
// Leaf node reached
939939
let mut probs = vec![0.0; self.classes().len()];
940940
probs[current_node.output] = 1.0;
941-
return probs;
941+
return Ok(probs);
942942
}
943943

944944
let split_feature = current_node.split_feature;
@@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
952952
}
953953

954954
// This should never happen if the tree is properly constructed
955-
vec![0.0; self.classes().len()]
955+
Err(Failed::predict("Nodes iteration did not reach leaf"))
956956
}
957957
}
958958

0 commit comments

Comments
 (0)