@@ -169,3 +169,102 @@ where
         self.mapv(|a| num_traits::clamp(a, min.clone(), max.clone()))
     }
 }
+
+#[cfg(feature = "std")]
+impl<A, S, D> ArrayBase<S, D>
+where
+    A: Float + 'static,
+    S: Data<Elem = A>,
+    D: RemoveAxis,
+{
+    /// Compute the softmax function along the specified axis.
+    ///
+    /// The softmax function is defined as:
+    /// ```text
+    /// softmax(x_i) = exp(x_i) / sum(exp(x_j) for j in axis)
+    /// ```
+    ///
+    /// This function is commonly used in machine learning to normalize the output of a neural network
+    /// into a probability distribution.
+    /// ```
+    /// use ndarray::{array, Axis};
+    ///
+    /// let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
+    /// let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+    /// let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+    /// assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+    /// ```
+    ///
+    /// # Arguments
+    ///
+    /// * `axis`: The axis along which to compute the softmax function, so that every lane along that axis sums to 1.
+    pub fn softmax(&self, axis: Axis) -> Array<A, D>
+    {
+        let mut res = Array::uninit(self.raw_dim());
+        for (arr, mut res) in self.lanes(axis).into_iter().zip(res.lanes_mut(axis)) {
+            let max = arr
+                .iter()
+                // If we have NaN and the comparison fails, the max can be arbitrary as the sum and the
+                // whole result will be NaN anyway, so we use an arbitrary ordering.
+                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
+            let max = match max {
+                Some(max) => *max,
+                None => continue,
+            };
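+            // Subtracting the lane maximum before exponentiating keeps `exp` from overflowing;
+            // shifting every input by the same constant does not change the softmax result.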
+            let mut sum = A::zero();
+            for (i, x) in res.indexed_iter_mut() {
+                let v = (arr[i] - max).exp();
+                sum = sum + v;
+                x.write(v);
+            }
+            for x in res.iter_mut() {
+                // Safety: we wrote to every element of this lane (`res`) in the previous loop.
+                x.write(*unsafe { x.assume_init_ref() } / sum);
+            }
+        }
+        // Safety: we wrote to every single element of the array.
+        unsafe { res.assume_init() }
+    }
+}
+
+#[cfg(test)]
+mod tests
+{
+    #[cfg(feature = "std")]
+    #[test]
+    fn test_softmax()
+    {
+        use super::*;
+        use crate::array;
+
+        let a = array![[1., 2., 3.], [4., 5., 6.0_f32]];
+        let b = a.softmax(Axis(0)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(b, array![[0.05, 0.05, 0.05], [0.95, 0.95, 0.95]]);
+        let c = a.softmax(Axis(1)).mapv(|x| (x * 100.0).round() / 100.0);
+        assert_eq!(c, array![[0.09, 0.24, 0.67], [0.09, 0.24, 0.67]]);
+
+        #[cfg(feature = "approx")]
+        {
+            // examples copied from the SciPy softmax documentation
+
+            use approx::assert_relative_eq;
+
+            let x = array![[1., 0.5, 0.2, 3.], [1., -1., 7., 3.], [2., 12., 13., 3.]];
+
+            let m = x.softmax(Axis(0));
+            let y = array![[0.211942, 0.00001013, 0.00000275, 0.333333],
+                           [0.211942, 0.00000226, 0.00247262, 0.333333],
+                           [0.576117, 0.999988, 0.997525, 0.333333]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+
+            let m = x.softmax(Axis(1));
+            let y = array![[1.05877e-01, 6.42177e-02, 4.75736e-02, 7.82332e-01],
+                           [2.42746e-03, 3.28521e-04, 9.79307e-01, 1.79366e-02],
+                           [1.22094e-05, 2.68929e-01, 7.31025e-01, 3.31885e-05]];
+            assert_relative_eq!(m, y, epsilon = 1e-5);
+        }
+    }
+}
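
For readers who want the numerical recipe outside of `ndarray`, here is a minimal standalone sketch of the same stable softmax computation the patch performs per lane (shift by the maximum, exponentiate, normalize). It is illustrative only and not part of the patch; the helper name `stable_softmax` and the use of `f64` slices are assumptions made for the example.

```rust
// Minimal sketch of the per-lane computation above, on a plain `f64` slice.
// `stable_softmax` is a hypothetical helper name, not an ndarray API.
fn stable_softmax(xs: &[f64]) -> Vec<f64> {
    // Shift by the maximum so `exp` cannot overflow; the shift cancels in the final ratio.
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // Matches the doc example's second row: softmax([4, 5, 6]) ≈ [0.09, 0.24, 0.67].
    let p = stable_softmax(&[4.0, 5.0, 6.0]);
    println!("{:?}", p);
    assert!((p.iter().sum::<f64>() - 1.0).abs() < 1e-9);
}
```

The identity behind the shift is that softmax(x) = softmax(x - c) for any constant c, which is why subtracting the per-lane maximum changes nothing but the floating-point behavior.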