@@ -903,14 +903,14 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
903903 ///
904904 /// # Errors
905905 ///
906- /// Returns an error if the prediction process fails.
906+ /// Returns an error if at least one row prediction process fails.
907907 pub fn predict_proba ( & self , x : & X ) -> Result < DenseMatrix < f64 > , Failed > {
908908 let ( n_samples, _) = x. shape ( ) ;
909909 let n_classes = self . classes ( ) . len ( ) ;
910910 let mut result = DenseMatrix :: < f64 > :: zeros ( n_samples, n_classes) ;
911911
912912 for i in 0 ..n_samples {
913- let probs = self . predict_proba_for_row ( x, i) ;
913+ let probs = self . predict_proba_for_row ( x, i) ? ;
914914 for ( j, & prob) in probs. iter ( ) . enumerate ( ) {
915915 result. set ( ( i, j) , prob) ;
916916 }
@@ -930,15 +930,15 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
930930 ///
931931 /// A vector of probabilities, one for each class, representing the probability
932932 /// of the input sample belonging to each class.
933- fn predict_proba_for_row ( & self , x : & X , row : usize ) -> Vec < f64 > {
933+ fn predict_proba_for_row ( & self , x : & X , row : usize ) -> Result < Vec < f64 > , Failed > {
934934 let mut node = 0 ;
935935
936936 while let Some ( current_node) = self . nodes ( ) . get ( node) {
937937 if current_node. true_child . is_none ( ) && current_node. false_child . is_none ( ) {
938938 // Leaf node reached
939939 let mut probs = vec ! [ 0.0 ; self . classes( ) . len( ) ] ;
940940 probs[ current_node. output ] = 1.0 ;
941- return probs;
941+ return Ok ( probs) ;
942942 }
943943
944944 let split_feature = current_node. split_feature ;
@@ -952,7 +952,7 @@ impl<TX: Number + PartialOrd, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
952952 }
953953
954954 // This should never happen if the tree is properly constructed
955- vec ! [ 0.0 ; self . classes ( ) . len ( ) ]
955+ Err ( Failed :: predict ( "Nodes iteration did not reach leaf" ) )
956956 }
957957}
958958
0 commit comments