
Commit 1dda703

feature: cpu softmax

1 parent 6d8fb10

3 files changed: +59 −13

crates/ratchet-core/src/cpu/mod.rs (+2 −1)

@@ -3,6 +3,7 @@ pub mod gemm;
 mod norm;
 pub mod reindex;
 pub mod rope;
+mod softmax;
 mod unary;
 mod utils;
 
@@ -21,7 +22,7 @@ pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError
         LazyOp::Binary(b) => b.apply_cpu(dst),
         LazyOp::Cast(c) => cpu_cast(c, dst),
         LazyOp::Matmul(m) => m.apply_cpu(dst),
-        LazyOp::Softmax(_s) => todo!(),
+        LazyOp::Softmax(s) => s.apply_cpu(dst),
         LazyOp::RoPE(r) => cpu_rope(r, dst),
         LazyOp::Unary(u) => u.apply_cpu(dst),
         LazyOp::Reindex(r) => r.apply_cpu(dst),
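
The new `LazyOp::Softmax` arm delegates to the `CPUOperation` trait, which the next file implements for `Softmax`. Reconstructed from the call sites in this commit (the real definition lives elsewhere in ratchet-core and may carry more methods or bounds), the trait looks roughly like:

    // Shape of the trait implied by `s.apply_cpu(dst)` above; a sketch only,
    // not the authoritative definition from the crate.
    pub trait CPUOperation {
        fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError>;
    }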
crates/ratchet-core/src/cpu/softmax.rs (new file, +42)

@@ -0,0 +1,42 @@
+use crate::cpu::utils::cpu_store_result;
+use crate::{CPUOperation, DType, OperationError, Softmax, Tensor, TensorDType};
+use half::{bf16, f16};
+use num::Float;
+use num_traits::NumAssignOps;
+
+impl CPUOperation for Softmax {
+    fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
+        let Softmax { input, dim } = self;
+        match input.dt() {
+            DType::F32 => softmax::<f32>(input, *dim, &dst)?,
+            DType::F16 => softmax::<f16>(input, *dim, &dst)?,
+            DType::BF16 => softmax::<bf16>(input, *dim, &dst)?,
+            _ => todo!(),
+        }
+
+        Ok(dst)
+    }
+}
+
+fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError>
+where
+    T: TensorDType + Float + NumAssignOps,
+{
+    let src_shape = input.shape();
+    let mut input = input.to_vec::<T>()?;
+    let N = src_shape[dim];
+    input.chunks_mut(N).for_each(|chunk| {
+        let mut sum = T::zero();
+        for j in 0..N {
+            chunk[j] = chunk[j].exp();
+            sum += chunk[j];
+        }
+        for j in 0..N {
+            chunk[j] /= sum;
+        }
+    });
+
+    cpu_store_result(dst, &input);
+
+    Ok(())
+}
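
Two things are worth noting about this kernel. First, reducing with `chunks_mut(N)` is only correct when `dim` is the innermost, contiguous axis (as in the `softmax(2)` tests below); any other `dim` would need a strided walk. Second, it exponentiates raw inputs, so large logits can overflow `exp()`, which bites quickly at f16/bf16 range. A minimal sketch of the conventional max-subtracted ("safe") variant of the per-row loop, under the same contiguous-row assumption:

    // Sketch of a numerically stable row softmax, same bounds as the
    // kernel above. Subtracting the row max keeps every exp() <= 1.
    fn softmax_row<T>(chunk: &mut [T])
    where
        T: num::Float + num_traits::NumAssignOps,
    {
        // Row max; rows are non-empty by construction.
        let max = chunk.iter().cloned().fold(T::neg_infinity(), T::max);
        let mut sum = T::zero();
        for x in chunk.iter_mut() {
            *x = (*x - max).exp();
            sum += *x;
        }
        for x in chunk.iter_mut() {
            *x /= sum;
        }
    }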

crates/ratchet-core/src/ops/softmax.rs (+15 −12)

@@ -13,8 +13,8 @@ use crate::{
 
 #[derive(new, Debug, Clone)]
 pub struct Softmax {
-    input: Tensor,
-    dim: usize,
+    pub(crate) input: Tensor,
+    pub(crate) dim: usize,
 }
 
 #[derive(Debug, derive_new::new, ShaderType, WgslMetadata)]
@@ -322,8 +322,7 @@ def softmax(a):
         run_py_prg(prg.to_string(), &[a], &[], a.dt())
     }
 
-    fn run_softmax_trial(problem: SoftmaxProblem) {
-        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+    fn run_softmax_trial(problem: SoftmaxProblem, device: Device) {
         let SoftmaxProblem { B, M, N } = problem;
         let a = Tensor::randn::<f32>(shape![B, M, N], Device::CPU);
         let ground = ground_truth(&a).unwrap();
@@ -332,8 +331,6 @@ def softmax(a):
         let b = a_gpu.softmax(2).unwrap().resolve().unwrap();
 
         let ours = b.to(&Device::CPU).unwrap();
-        println!("ours = {:?}", ours);
-        println!("ground = {:?}", ground);
         ground.all_close(&ours, 1e-6, 1e-6).unwrap();
     }
 
@@ -347,16 +344,22 @@ def softmax(a):
         N: usize,
     }
 
-    #[proptest(cases = 8)]
-    fn test_softmax(prob: SoftmaxProblem) {
-        let SoftmaxProblem { B, M, N } = prob;
-        println!("B = {}, M = {}, N = {}", B, M, N);
-        run_softmax_trial(prob);
+    #[proptest(cases = 18)]
+    fn test_softmax_gpu(prob: SoftmaxProblem) {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
+        run_softmax_trial(prob, device);
+    }
+
+    #[proptest(cases = 16)]
+    fn test_softmax_cpu(prob: SoftmaxProblem) {
+        let device = Device::request_device(DeviceRequest::CPU).unwrap();
+        run_softmax_trial(prob, device);
     }
 
     #[test]
     fn dbg_softmax() {
+        let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let problem = SoftmaxProblem { B: 1, M: 2, N: 128 };
-        run_softmax_trial(problem);
+        run_softmax_trial(problem, device);
     }
 }
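
With the device threaded through `run_softmax_trial`, the same property tests now exercise both backends. A hypothetical end-to-end use of the new CPU path, reusing only names that appear in this commit (error handling elided with `unwrap`; `a_dev` and the `to(&device)` transfer are assumptions modeled on the test code):

    // Hypothetical usage; identifiers taken from the tests above.
    let device = Device::request_device(DeviceRequest::CPU).unwrap();
    let a = Tensor::randn::<f32>(shape![1, 2, 128], Device::CPU);
    let a_dev = a.to(&device).unwrap();
    // Softmax over the innermost axis; on CPU this now dispatches through
    // apply_operation to Softmax::apply_cpu instead of hitting todo!().
    let b = a_dev.softmax(2).unwrap().resolve().unwrap();
    let ours = b.to(&Device::CPU).unwrap();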
