@@ -13,8 +13,8 @@ use crate::{
13
13
14
14
#[ derive( new, Debug , Clone ) ]
15
15
pub struct Softmax {
16
- input : Tensor ,
17
- dim : usize ,
16
+ pub ( crate ) input : Tensor ,
17
+ pub ( crate ) dim : usize ,
18
18
}
19
19
20
20
#[ derive( Debug , derive_new:: new, ShaderType , WgslMetadata ) ]
@@ -322,8 +322,7 @@ def softmax(a):
322
322
run_py_prg ( prg. to_string ( ) , & [ a] , & [ ] , a. dt ( ) )
323
323
}
324
324
325
- fn run_softmax_trial ( problem : SoftmaxProblem ) {
326
- let device = Device :: request_device ( DeviceRequest :: GPU ) . unwrap ( ) ;
325
+ fn run_softmax_trial ( problem : SoftmaxProblem , device : Device ) {
327
326
let SoftmaxProblem { B, M, N } = problem;
328
327
let a = Tensor :: randn :: < f32 > ( shape ! [ B , M , N ] , Device :: CPU ) ;
329
328
let ground = ground_truth ( & a) . unwrap ( ) ;
@@ -332,8 +331,6 @@ def softmax(a):
332
331
let b = a_gpu. softmax ( 2 ) . unwrap ( ) . resolve ( ) . unwrap ( ) ;
333
332
334
333
let ours = b. to ( & Device :: CPU ) . unwrap ( ) ;
335
- println ! ( "ours = {:?}" , ours) ;
336
- println ! ( "ground = {:?}" , ground) ;
337
334
ground. all_close ( & ours, 1e-6 , 1e-6 ) . unwrap ( ) ;
338
335
}
339
336
@@ -347,16 +344,22 @@ def softmax(a):
347
344
N : usize ,
348
345
}
349
346
350
- #[ proptest( cases = 8 ) ]
351
- fn test_softmax ( prob : SoftmaxProblem ) {
352
- let SoftmaxProblem { B, M, N } = prob;
353
- println ! ( "B = {}, M = {}, N = {}" , B , M , N ) ;
354
- run_softmax_trial ( prob) ;
347
+ #[ proptest( cases = 18 ) ]
348
+ fn test_softmax_gpu ( prob : SoftmaxProblem ) {
349
+ let device = Device :: request_device ( DeviceRequest :: GPU ) . unwrap ( ) ;
350
+ run_softmax_trial ( prob, device) ;
351
+ }
352
+
353
+ #[ proptest( cases = 16 ) ]
354
+ fn test_softmax_cpu ( prob : SoftmaxProblem ) {
355
+ let device = Device :: request_device ( DeviceRequest :: CPU ) . unwrap ( ) ;
356
+ run_softmax_trial ( prob, device) ;
355
357
}
356
358
357
359
#[ test]
358
360
fn dbg_softmax ( ) {
361
+ let device = Device :: request_device ( DeviceRequest :: GPU ) . unwrap ( ) ;
359
362
let problem = SoftmaxProblem { B : 1 , M : 2 , N : 128 } ;
360
- run_softmax_trial ( problem) ;
363
+ run_softmax_trial ( problem, device ) ;
361
364
}
362
365
}
0 commit comments