Skip to content

Commit 8881439

Browse files
committed
Fix #245: return error for NaN in naive bayes
1 parent 83dcf9a commit 8881439

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
clippy::upper_case_acronyms
77
)]
88
#![warn(missing_docs)]
9-
#![warn(rustdoc::missing_doc_code_examples)]
109

1110
//! # smartcore
1211
//!

src/naive_bayes/gaussian.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,27 @@ mod tests {
425425
);
426426
}
427427

428+
// Test added by this commit (titled "Fix #245: return error for NaN in naive bayes").
// It fits a GaussianNB on four samples where class `2` has only a single row,
// then checks that `predict` returns an error instead of a prediction.
// NOTE(review): presumably the single-sample class is what makes the
// log-likelihood NaN; the distribution code is not visible here, so confirm.
#[test]
429+
fn run_gaussian_naive_bayes_with_few_samples() {
430+
let x = DenseMatrix::<f64>::from_2d_array(&[
431+
&[-1., -1.],
432+
&[-2., -1.],
433+
&[-3., -2.],
434+
&[1., 1.],
435+
]);
436+
// Labels: three rows of class 1 and exactly one row of class 2.
let y: Vec<u32> = vec![1, 1, 1, 2];
437+
438+
let gnb = GaussianNB::fit(&x, &y, Default::default());
439+
440+
// `fit` is expected to succeed (it is unwrapped); the failure must come from `predict`.
match gnb.unwrap().predict(&x) {
441+
// Reaching `Ok` means no error was reported, so the test fails unconditionally.
Ok(_) => assert!(false, "test should return Failed"),
442+
Err(err) => {
443+
// Pins the exact rendered error text, including the "Can't find solution: " prefix.
assert!(err.to_string() == "Can't find solution: log_likelihood for distribution of one of the rows is NaN");
444+
// This assertion can never fail; it has no effect on the test outcome.
assert!(true)
445+
},
446+
}
447+
}
448+
428449
#[cfg_attr(
429450
all(target_arch = "wasm32", not(target_os = "wasi")),
430451
wasm_bindgen_test::wasm_bindgen_test

src/naive_bayes/mod.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
//!
3636
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
3737
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
38-
use crate::error::Failed;
38+
use crate::error::{Failed, FailedError};
3939
use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1};
4040
use crate::numbers::basenum::Number;
4141
#[cfg(feature = "serde")]
@@ -93,24 +93,36 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX,
9393
/// Predicts one class label per row of `x`.
///
/// For every row, each class is scored as
/// `log_likelihood(class_index, row) + ln(prior(class_index))` and the class
/// with the highest score is selected.
///
/// # Errors
///
/// With the lines added in this diff, returns `Failed` with
/// `FailedError::SolutionFailed` if any log-likelihood evaluated to NaN.
///
/// NOTE: this listing is a rendered diff. Lines preceded by a gutter marker
/// ending in `-` are the removed (old) version and those ending in `+` are
/// the added (new) version; both appear below.
pub fn predict(&self, x: &X) -> Result<Y, Failed> {
9494
let y_classes = self.distribution.classes();
9595
let (rows, _) = x.shape();
96+
// Set when any (row, class) log-likelihood is NaN; checked after all rows are scored.
// (The identifier is spelled `likehood` in the source.)
let mut log_likehood_is_nan = false;
9697
let predictions = (0..rows)
9798
.map(|row_index| {
9899
let row = x.get_row(row_index);
99100
let (prediction, _probability) = y_classes
100101
.iter()
101102
.enumerate()
102103
.map(|(class_index, class)| {
104+
let mut log_likelihood = self.distribution.log_likelihood(class_index, &row);
105+
// A NaN is replaced by 0 and recorded in the flag, so scoring continues.
// NOTE(review): presumably this keeps `partial_cmp(..).unwrap()` below from
// panicking on a NaN score; the substituted value is discarded because the
// whole call returns an error once the flag is set.
if log_likelihood.is_nan() {
106+
log_likelihood = 0f64;
107+
log_likehood_is_nan = true;
108+
}
103109
(
104110
class,
105-
// Old (removed) version: the log-likelihood was computed inline here.
self.distribution.log_likelihood(class_index, &row)
111+
// New (added) version: uses the NaN-checked local instead.
log_likelihood
106112
+ self.distribution.prior(class_index).ln(),
107113
)
108-
})
114+
})
109115
// `partial_cmp` yields `None` when either score is NaN, which would make this `unwrap` panic.
.max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap())
110116
// `max_by` returns `None` only for an empty iterator, i.e. when there are no classes.
.unwrap();
111117
*prediction
112118
})
113119
.collect::<Vec<TY>>();
120+
// Report the failure only after every row has been processed; predictions are dropped.
if log_likehood_is_nan {
121+
return Err(Failed::because(
122+
FailedError::SolutionFailed,
123+
"log_likelihood for distribution of one of the rows is NaN",
124+
));
125+
}
114126
let y_hat = Y::from_vec_slice(&predictions);
115127
Ok(y_hat)
116128
}

0 commit comments

Comments
 (0)