@@ -40,7 +40,7 @@ use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
4040use crate :: numbers:: basenum:: Number ;
4141#[ cfg( feature = "serde" ) ]
4242use serde:: { Deserialize , Serialize } ;
43- use std:: marker:: PhantomData ;
43+ use std:: { cmp :: Ordering , marker:: PhantomData } ;
4444
4545/// Distribution used in the Naive Bayes classifier.
4646pub ( crate ) trait NBDistribution < X : Number , Y : Number > : Clone {
@@ -92,11 +92,9 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
9292 /// Returns a vector of size N with class estimates.
9393 pub fn predict ( & self , x : & X ) -> Result < Y , Failed > {
9494 let y_classes = self . distribution . classes ( ) ;
95- let ( rows, _) = x. shape ( ) ;
96- let predictions = ( 0 ..rows)
97- . map ( |row_index| {
98- let row = x. get_row ( row_index) ;
99- let ( prediction, _probability) = y_classes
95+ let predictions = x. row_iter ( )
96+ . map ( |row| {
97+ y_classes
10098 . iter ( )
10199 . enumerate ( )
102100 . map ( |( class_index, class) | {
@@ -106,11 +104,28 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
106104 + self . distribution . prior ( class_index) . ln ( ) ,
107105 )
108106 } )
109- . max_by ( |( _, p1) , ( _, p2) | p1. partial_cmp ( p2) . unwrap ( ) )
110- . unwrap ( ) ;
111- * prediction
107+ // For some reason, the max_by method cannot use NaNs for finding the maximum value, it panics.
108+ // NaN must be considered as minimum values,
109+ // therefore it's like NaNs would not be considered for choosing the maximum value.
110+ // So we need to handle this case for avoiding panicking by using `Option::unwrap`.
111+ . max_by ( |( _, p1) , ( _, p2) | {
112+ match p1. partial_cmp ( p2) {
113+ Some ( ordering) => ordering,
114+ None => {
115+ if p1. is_nan ( ) {
116+ Ordering :: Less
117+ } else if p2. is_nan ( ) {
118+ Ordering :: Greater
119+ } else {
120+ Ordering :: Equal
121+ }
122+ }
123+ }
124+ } )
125+ . map ( |( prediction, _probability) | * prediction)
126+ . ok_or_else ( || Failed :: predict ( "Failed to predict, there is no result" ) )
112127 } )
113- . collect :: < Vec < TY > > ( ) ;
128+ . collect :: < Result < Vec < TY > , Failed > > ( ) ? ;
114129 let y_hat = Y :: from_vec_slice ( & predictions) ;
115130 Ok ( y_hat)
116131 }
@@ -119,3 +134,56 @@ pub mod bernoulli;
119134pub mod categorical;
120135pub mod gaussian;
121136pub mod multinomial;
137+
138+ #[ cfg( test) ]
139+ mod tests {
140+ use super :: * ;
141+ use crate :: linalg:: basic:: matrix:: DenseMatrix ;
142+ use num_traits:: float:: Float ;
143+ use crate :: linalg:: basic:: arrays:: Array ;
144+
145+ type Model < ' d > = BaseNaiveBayes < i32 , i32 , DenseMatrix < i32 > , Vec < i32 > , TestDistribution < ' d > > ;
146+
147+ #[ derive( Debug , PartialEq , Clone ) ]
148+ struct TestDistribution < ' d > ( & ' d Vec < i32 > ) ;
149+
150+ impl < ' d > NBDistribution < i32 , i32 > for TestDistribution < ' d > {
151+ fn prior ( & self , _class_index : usize ) -> f64 {
152+ 1.
153+ }
154+
155+ fn log_likelihood < ' a > ( & ' a self , class_index : usize , _j : & ' a Box < dyn ArrayView1 < i32 > + ' a > ) -> f64 {
156+ match self . 0 . get ( class_index) {
157+ & v @ 2 | & v @ 10 | & v @ 20 => v as f64 ,
158+ _ => f64:: nan ( ) ,
159+ }
160+ }
161+
162+ fn classes ( & self ) -> & Vec < i32 > {
163+ & self . 0
164+ }
165+ }
166+
167+ #[ test]
168+ fn test_predict ( ) {
169+ let matrix = DenseMatrix :: from_2d_array ( & [ & [ 1 , 2 , 3 ] , & [ 4 , 5 , 6 ] , & [ 7 , 8 , 9 ] ] ) ;
170+
171+ let val = vec ! [ ] ;
172+ match Model :: fit ( TestDistribution ( & val) ) . unwrap ( ) . predict ( & matrix) {
173+ Ok ( _) => panic ! ( "Should return error in case of empty classes" ) ,
174+ Err ( err) => assert_eq ! ( err. to_string( ) , "Predict failed: Failed to predict, there is no result" ) ,
175+ }
176+
177+ let val = vec ! [ 1 , 2 , 3 ] ;
178+ match Model :: fit ( TestDistribution ( & val) ) . unwrap ( ) . predict ( & matrix) {
179+ Ok ( r) => assert_eq ! ( r, vec![ 2 , 2 , 2 ] ) ,
180+ Err ( _) => panic ! ( "Should success in normal case with NaNs" ) ,
181+ }
182+
183+ let val = vec ! [ 20 , 2 , 10 ] ;
184+ match Model :: fit ( TestDistribution ( & val) ) . unwrap ( ) . predict ( & matrix) {
185+ Ok ( r) => assert_eq ! ( r, vec![ 20 , 20 , 20 ] ) ,
186+ Err ( _) => panic ! ( "Should success in normal case without NaNs" ) ,
187+ }
188+ }
189+ }
0 commit comments