|
35 | 35 | //! |
36 | 36 | //! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script> |
37 | 37 | //! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> |
38 | | -use crate::error::Failed; |
| 38 | +use crate::error::{Failed, FailedError}; |
39 | 39 | use crate::linalg::basic::arrays::{Array1, Array2, ArrayView1}; |
40 | 40 | use crate::numbers::basenum::Number; |
41 | 41 | #[cfg(feature = "serde")] |
@@ -93,24 +93,36 @@ impl<TX: Number, TY: Number, X: Array2<TX>, Y: Array1<TY>, D: NBDistribution<TX, |
93 | 93 | pub fn predict(&self, x: &X) -> Result<Y, Failed> { |
94 | 94 | let y_classes = self.distribution.classes(); |
95 | 95 | let (rows, _) = x.shape(); |
| 96 | + let mut log_likehood_is_nan = false; |
96 | 97 | let predictions = (0..rows) |
97 | 98 | .map(|row_index| { |
98 | 99 | let row = x.get_row(row_index); |
99 | 100 | let (prediction, _probability) = y_classes |
100 | 101 | .iter() |
101 | 102 | .enumerate() |
102 | 103 | .map(|(class_index, class)| { |
| 104 | + let mut log_likelihood = self.distribution.log_likelihood(class_index, &row); |
| 105 | + if log_likelihood.is_nan() { |
| 106 | + log_likelihood = 0f64; |
| 107 | + log_likehood_is_nan = true; |
| 108 | + } |
103 | 109 | ( |
104 | 110 | class, |
105 | | - self.distribution.log_likelihood(class_index, &row) |
| 111 | + log_likelihood |
106 | 112 | + self.distribution.prior(class_index).ln(), |
107 | 113 | ) |
108 | | - }) |
| 114 | + }) |
109 | 115 | .max_by(|(_, p1), (_, p2)| p1.partial_cmp(p2).unwrap()) |
110 | 116 | .unwrap(); |
111 | 117 | *prediction |
112 | 118 | }) |
113 | 119 | .collect::<Vec<TY>>(); |
| 120 | + if log_likehood_is_nan { |
| 121 | + return Err(Failed::because( |
| 122 | + FailedError::SolutionFailed, |
| 123 | + "log_likelihood for distribution of one of the rows is NaN", |
| 124 | + )); |
| 125 | + } |
114 | 126 | let y_hat = Y::from_vec_slice(&predictions); |
115 | 127 | Ok(y_hat) |
116 | 128 | } |
|
0 commit comments