@@ -169,3 +169,94 @@ where
         self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
     }
 }
+
+#[cfg(feature = "std")]
+impl<A, S, D> ArrayBase<S, D>
+where
+    A: Float + 'static,
+    S: Data<Elem = A>,
+    D: RemoveAxis,
+{
+    /// Compute the softmax function along the specified axis.
+    ///
+    /// The softmax function is defined as:
+    /// ```text
+    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
+    /// ```
+    ///
+    /// It is commonly used in machine learning to normalize the output of a neural network to a
+    /// probability distribution.
+    ///
+    /// ```
+    /// use ndarray::{array, Axis};
+    ///
+    /// let a = array![[1., 2., 3.], [4., 5., 6.]];
+    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * `axis`: The axis along which to compute the softmax, so that every lane along it sums to 1.
+    pub fn softmax(&self, axis: Axis) -> Array<A, D>
+    {
+        let mut res = Array::uninit(self.raw_dim());
+        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
+            let max = arr
+                .iter()
+                // If a lane contains NaN, the comparison fails and the max can be arbitrary:
+                // the sum and the whole result will be NaN anyway, so any ordering works.
+                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let max = match max {
+                Some(max) => *max,
+                // An empty lane has no elements to write; skip it.
+                None => continue,
+            };
+            // Subtracting the per-lane maximum before exponentiating prevents exp() from
+            // overflowing for large inputs; the shift cancels in the ratio, leaving the result unchanged.
+            let sum = arr.fold(A::zero(), |sum, x| sum + (*x - max).exp());
+            for (i, x) in res.indexed_iter_mut() {
+                x.write((arr[i] - max).exp() / sum);
+            }
+        }
+        // Safety: every element of every non-empty lane was written above, and empty lanes
+        // contain no elements, so the whole array is initialized.
+        unsafe { res.assume_init() }
+    }
+}
+
+#[cfg(test)]
+mod tests
+{
+    #[cfg(feature = "std")]
+    #[test]
+    fn test_softmax()
+    {
+        use super::*;
+        use crate::array;
+
+        let a = array![[1., 2., 3.], [4., 5., 6.]];
+        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+
+        #[cfg(feature = "approx")]
+        {
+            // Examples copied from the scipy softmax documentation.
+
+            use approx::assert_relative_eq;
+
+            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];
+
+            let m = x.softmax(Axis(0));
+            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
+                           [0.211942, 0.00000226, 0.00247262, 0.333333],
+                           [0.576117, 0.999988, 0.997525, 0.333333]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+
+            let m = x.softmax(Axis(1));
+            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
+                           [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
+                           [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+        }
+    }
+}
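Not part of the patch, but a quick way to see why the method subtracts the per-lane maximum: a minimal out-of-tree sketch (assuming a build of `ndarray` with the `std` feature and this patch applied) on inputs where a naive `exp(x) / sum(exp(x))` would overflow `f64`:

```rust
use ndarray::{array, Axis};

fn main() {
    // exp(1002.0_f64) alone overflows to +inf, so a naive exp(x) / sum(exp(x))
    // softmax would produce NaN here. Subtracting the per-lane maximum first,
    // as the patched method does, keeps every exponent <= 0.
    let logits = array![[1000.0_f64, 1001.0, 1002.0]];
    let probs = logits.softmax(Axis(1));

    // Every value is finite and the lane along Axis(1) sums to 1.
    assert!(probs.iter().all(|p| p.is_finite()));
    let sum: f64 = probs.row(0).sum();
    assert!((sum - 1.0).abs() < 1e-12);
    println!("{:?}", probs); // approximately [[0.090, 0.245, 0.665]]
}
```

The shift multiplies numerator and denominator by the same constant `exp(-max)`, so it cannot change the result; it only moves the computation into a range where `exp` stays finite.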