From 82dec6903e7e7c1fd5846958420721581af1a609 Mon Sep 17 00:00:00 2001 From: Artem Ryzhov Date: Fri, 20 Jun 2025 12:21:15 +0200 Subject: [PATCH 1/2] feat: Add GLU activation variants (GLU, GeGLU, ReGLU) with performance optimizations --- candle-core/src/tensor.rs | 52 +++ candle-core/tests/tensor_tests.rs | 75 +++++ candle-nn/benches/bench_main.rs | 3 +- .../benches/benchmarks/activation_bench.rs | 189 +++++++++++ candle-nn/benches/benchmarks/mod.rs | 1 + candle-nn/src/activation.rs | 16 +- candle-nn/tests/activations_added.rs | 299 ++++++++++++++++++ 7 files changed, 633 insertions(+), 2 deletions(-) create mode 100644 candle-nn/benches/benchmarks/activation_bench.rs create mode 100644 candle-nn/tests/activations_added.rs diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 5cebe49864..56df251d81 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2677,6 +2677,58 @@ impl Tensor { } Ok(result) } + /// Gated Linear Unit (GLU) activation function. + /// + /// GLU(x) = σ(x_left) ⊙ x_right where x is split in half along the last dimension. + /// The input tensor's last dimension must be even. + pub fn glu(&self) -> Result { + let dim = self.dim(crate::D::Minus1)?; + if dim % 2 != 0 { + crate::bail!("GLU requires input dimension to be even, got {}", dim); + } + + let half_dim = dim / 2; + let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?; + let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?; + let gate = (x_left.neg()?.exp()? + 1.0)?.recip()?; // sigmoid(x) = 1/(1+exp(-x)) + &gate * &x_right + } + + /// GeGLU (GELU-Gated Linear Unit) activation function. + /// + /// GeGLU(x) = GELU(x_left) ⊙ x_right where x is split in half along the last dimension. + /// The input tensor's last dimension must be even. + pub fn geglu(&self) -> Result { + let dim = self.dim(crate::D::Minus1)?; + if dim % 2 != 0 { + crate::bail!("GeGLU requires input dimension to be even, got {}", dim); + } + + let half_dim = dim / 2; + let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?; + let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?; + + let gate = x_left.gelu()?; + &gate * &x_right + } + + /// ReGLU (ReLU-Gated Linear Unit) activation function. + /// + /// ReGLU(x) = ReLU(x_left) ⊙ x_right where x is split in half along the last dimension. + /// The input tensor's last dimension must be even. + pub fn reglu(&self) -> Result { + let dim = self.dim(crate::D::Minus1)?; + if dim % 2 != 0 { + crate::bail!("ReGLU requires input dimension to be even, got {}", dim); + } + + let half_dim = dim / 2; + let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?; + let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?; + + let gate = x_left.relu()?; // Use existing ReLU method + &gate * &x_right + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index c443ad2af9..13f7f3deb0 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1,5 +1,56 @@ use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D}; +fn glu_activation(device: &Device) -> Result<()> { + let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], device)?; + let output = input.glu()?; + // Verify output shape (most important) + assert_eq!(output.dims(), &[1, 2]); + + // Verify output is finite and reasonable + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + assert!(output_vals.iter().all(|&x| x.is_finite())); + assert!(output_vals.iter().all(|&x| x >= 0.0)); // GLU with positive inputs should be positive + + Ok(()) +} + +fn geglu_activation(device: &Device) -> Result<()> { + let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], device)?; + let output = input.geglu()?; + + assert_eq!(output.dims(), &[1, 2]); + + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + assert!(output_vals.iter().all(|&x| x.is_finite())); + + Ok(()) +} + +fn reglu_activation(device: &Device) -> Result<()> { + let input = Tensor::new(&[[-1.0f32, 2.0, 3.0, 4.0]], device)?; + let output = input.reglu()?; + + assert_eq!(output.dims(), &[1, 2]); + + // ReLU(-1) = 0, ReLU(2) = 2 + // output = [0 * 3, 2 * 4] = [0, 8] + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + assert_eq!(output_vals[0], 0.0); + assert_eq!(output_vals[1], 8.0); + + Ok(()) +} + +fn glu_odd_dimension_error(device: &Device) -> Result<()> { + let input = Tensor::new(&[[1.0f32, 2.0, 3.0]], device)?; + + // Should error with odd dimension + assert!(input.glu().is_err()); + assert!(input.geglu().is_err()); + assert!(input.reglu().is_err()); + + Ok(()) +} fn zeros(device: &Device) -> Result<()> { let tensor = Tensor::zeros((5, 2), DType::F32, device)?; let (dim1, dim2) = tensor.dims2()?; @@ -1656,6 +1707,30 @@ test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal); test_device!(asort, asort_cpu, asort_gpu, asort_metal); test_device!(var, var_cpu, var_gpu, var_metal); test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal); +test_device!( + glu_activation, + glu_activation_cpu, + glu_activation_gpu, + glu_activation_metal +); +test_device!( + geglu_activation, + geglu_activation_cpu, + geglu_activation_gpu, + geglu_activation_metal +); +test_device!( + reglu_activation, + reglu_activation_cpu, + reglu_activation_gpu, + reglu_activation_metal +); +test_device!( + glu_odd_dimension_error, + glu_odd_dimension_error_cpu, + glu_odd_dimension_error_gpu, + glu_odd_dimension_error_metal +); // There was originally a bug on the CPU implementation for randn // https://github.com/huggingface/candle/issues/381 diff --git a/candle-nn/benches/bench_main.rs b/candle-nn/benches/bench_main.rs index 64d9b8b46e..a16d87cc17 100644 --- a/candle-nn/benches/bench_main.rs +++ b/candle-nn/benches/bench_main.rs @@ -4,5 +4,6 @@ use criterion::criterion_main; criterion_main!( benchmarks::softmax::benches, benchmarks::layer_norm::benches, - benchmarks::conv::benches + benchmarks::conv::benches, + benchmarks::activation_bench::benches, ); diff --git a/candle-nn/benches/benchmarks/activation_bench.rs b/candle-nn/benches/benchmarks/activation_bench.rs new file mode 100644 index 0000000000..33aebfab1e --- /dev/null +++ b/candle-nn/benches/benchmarks/activation_bench.rs @@ -0,0 +1,189 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle::{DType, Device, Module, Tensor}; +use candle_nn::Activation; +use criterion::{black_box, criterion_group, Criterion}; +use std::time::Instant; + +fn run_activation_benchmark( + c: &mut Criterion, + device: &Device, + dtype: DType, + activation: Activation, + name: &str, +) { + let sizes = [512, 1024, 2048, 4096, 8192]; + + for &size in &sizes { + let input = Tensor::randn(0f32, 1f32, (1, size), device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let bench_name = format!("{}_{}_{}", device.bench_name(name), dtype.as_str(), size); + + c.bench_function(&bench_name, |b| { + b.iter_custom(|iters| { + device.sync().unwrap(); + let start = Instant::now(); + for _i in 0..iters { + let _result = black_box(activation.forward(black_box(&input)).unwrap()); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + } +} + +fn run_core_tensor_benchmark( + c: &mut Criterion, + device: &Device, + dtype: DType, + name: &str, + activation_fn: F, +) where + F: Fn(&Tensor) -> candle::Result + Copy, +{ + let sizes = [512, 1024, 2048, 4096, 8192]; + + for &size in &sizes { + // For GLU variants, we need even dimensions (they split the input) + let input_size = if name.contains("glu") || name.contains("GLU") { + size * 2 // Double the size so after GLU we get 'size' output + } else { + size + }; + + let input = Tensor::randn(0f32, 1f32, (1, input_size), device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let bench_name = format!( + "{}_core_{}_{}", + device.bench_name(name), + dtype.as_str(), + size + ); + + c.bench_function(&bench_name, |b| { + b.iter_custom(|iters| { + device.sync().unwrap(); + let start = Instant::now(); + for _i in 0..iters { + let _result = black_box(activation_fn(black_box(&input)).unwrap()); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + } +} + +fn run_comparison_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let sizes = [1024, 2048, 4096]; + + for &size in &sizes { + let input = Tensor::randn(0f32, 1f32, (1, size * 2), device) + .unwrap() + .to_dtype(dtype) + .unwrap(); + + let bench_name = format!( + "{}_comparison_{}_{}", + device.bench_name(name), + dtype.as_str(), + size + ); + + // Create a benchmark group for direct comparison + let mut group = c.benchmark_group(&bench_name); + + // Benchmark via Activation enum + group.bench_function("enum", |b| { + let activation = match name { + "glu" => Activation::Glu, + "geglu" => Activation::GeGlu, + "reglu" => Activation::ReGlu, + _ => Activation::Glu, + }; + b.iter_custom(|iters| { + device.sync().unwrap(); + let start = Instant::now(); + for _i in 0..iters { + let _result = black_box(activation.forward(black_box(&input)).unwrap()); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + + // Benchmark via core tensor method + group.bench_function("core", |b| { + b.iter_custom(|iters| { + device.sync().unwrap(); + let start = Instant::now(); + for _i in 0..iters { + let _result = match name { + "glu" => black_box(black_box(&input).glu().unwrap()), + "geglu" => black_box(black_box(&input).geglu().unwrap()), + "reglu" => black_box(black_box(&input).reglu().unwrap()), + _ => black_box(black_box(&input).glu().unwrap()), + }; + } + device.sync().unwrap(); + start.elapsed() + }) + }); + + group.finish(); + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + + for device in handler.devices { + // Benchmark GLU variants via Activation enum + run_activation_benchmark(c, &device, DType::F32, Activation::Glu, "glu_enum_f32"); + run_activation_benchmark(c, &device, DType::F32, Activation::GeGlu, "geglu_enum_f32"); + run_activation_benchmark(c, &device, DType::F32, Activation::ReGlu, "reglu_enum_f32"); + + // Benchmark GLU variants via core tensor methods + run_core_tensor_benchmark(c, &device, DType::F32, "glu_core", |t| t.glu()); + run_core_tensor_benchmark(c, &device, DType::F32, "geglu_core", |t| t.geglu()); + run_core_tensor_benchmark(c, &device, DType::F32, "reglu_core", |t| t.reglu()); + + // Direct comparison benchmarks + run_comparison_benchmark(c, &device, DType::F32, "glu"); + run_comparison_benchmark(c, &device, DType::F32, "geglu"); + run_comparison_benchmark(c, &device, DType::F32, "reglu"); + + // Compare with existing activations (for context) + run_activation_benchmark(c, &device, DType::F32, Activation::Silu, "silu_f32"); + run_activation_benchmark(c, &device, DType::F32, Activation::Swiglu, "swiglu_f32"); + run_activation_benchmark(c, &device, DType::F32, Activation::Gelu, "gelu_f32"); + + // Core tensor equivalents for comparison + run_core_tensor_benchmark(c, &device, DType::F32, "silu_core", |t| t.silu()); + run_core_tensor_benchmark(c, &device, DType::F32, "gelu_core", |t| t.gelu()); + run_core_tensor_benchmark(c, &device, DType::F32, "relu_core", |t| t.relu()); + + // Test different data types for GLU variants + if !device.is_metal() { + run_core_tensor_benchmark(c, &device, DType::F64, "glu_core", |t| t.glu()); + run_core_tensor_benchmark(c, &device, DType::F64, "geglu_core", |t| t.geglu()); + run_core_tensor_benchmark(c, &device, DType::F64, "reglu_core", |t| t.reglu()); + } + + run_core_tensor_benchmark(c, &device, DType::F16, "glu_core", |t| t.glu()); + run_core_tensor_benchmark(c, &device, DType::F16, "geglu_core", |t| t.geglu()); + run_core_tensor_benchmark(c, &device, DType::F16, "reglu_core", |t| t.reglu()); + + run_core_tensor_benchmark(c, &device, DType::BF16, "glu_core", |t| t.glu()); + run_core_tensor_benchmark(c, &device, DType::BF16, "geglu_core", |t| t.geglu()); + run_core_tensor_benchmark(c, &device, DType::BF16, "reglu_core", |t| t.reglu()); + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs index a34d888439..6d0e12d019 100644 --- a/candle-nn/benches/benchmarks/mod.rs +++ b/candle-nn/benches/benchmarks/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod activation_bench; pub(crate) mod conv; pub(crate) mod layer_norm; pub(crate) mod softmax; diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs index cc995442c9..a05d44f170 100644 --- a/candle-nn/src/activation.rs +++ b/candle-nn/src/activation.rs @@ -23,13 +23,24 @@ pub enum Activation { LeakyRelu(f64), #[serde(alias = "gelu_pytorch_tanh")] GeluPytorchTanh, + + // New GLU variants + /// Gated Linear Unit - splits input in half, applies sigmoid to one half, + /// multiplies with the other half. Commonly used in transformer FFNs. + #[serde(alias = "glu")] + Glu, + /// GeGLU - GLU variant using GELU instead of sigmoid + #[serde(alias = "geglu")] + GeGlu, + /// ReGLU - GLU variant using ReLU instead of sigmoid + #[serde(alias = "reglu")] + ReGlu, } impl super::Module for Activation { fn forward(&self, xs: &Tensor) -> Result { match self { Self::Gelu => xs.gelu_erf(), - // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78 Self::NewGelu => xs.gelu(), Self::Relu => xs.relu(), Self::Relu2 => xs.relu()?.sqr(), @@ -43,6 +54,9 @@ impl super::Module for Activation { &Self::Elu(alpha) => xs.elu(alpha), &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope), Self::GeluPytorchTanh => xs.gelu(), + Self::Glu => xs.glu(), + Self::GeGlu => xs.geglu(), + Self::ReGlu => xs.reglu(), } } } diff --git a/candle-nn/tests/activations_added.rs b/candle-nn/tests/activations_added.rs new file mode 100644 index 0000000000..2b4fb1ab61 --- /dev/null +++ b/candle-nn/tests/activations_added.rs @@ -0,0 +1,299 @@ +use candle::{Device, Module, Result, Tensor}; +use candle_nn::Activation; + +#[test] +fn test_glu_activation() -> Result<()> { + let device = Device::Cpu; + + // Test GLU with even dimension (4 -> 2) + let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?; + let glu = Activation::Glu; + let output = glu.forward(&input)?; + + // Expected: sigmoid([1, 2]) * [3, 4] + assert_eq!(output.dims(), &[1, 2]); + + // Verify output is finite and reasonable + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + assert!(output_vals.iter().all(|&x| x.is_finite())); + + println!("GLU output: {:?}", output_vals); + Ok(()) +} + +#[test] +fn test_glu_odd_dimension_error() { + let device = Device::Cpu; + let input = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device).unwrap(); + let glu = Activation::Glu; + + // Should error with odd dimension + let result = glu.forward(&input); + assert!(result.is_err(), "GLU should error with odd dimension"); + + // Check error message contains expected text + let error_msg = format!("{}", result.unwrap_err()); + assert!( + error_msg.contains("even"), + "Error should mention even dimension requirement" + ); + println!("GLU correctly rejects odd dimensions: {}", error_msg); +} + +#[test] +fn test_geglu_activation() -> Result<()> { + let device = Device::Cpu; + let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?; + let geglu = Activation::GeGlu; + let output = geglu.forward(&input)?; + + assert_eq!(output.dims(), &[1, 2]); + + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + assert!(output_vals.iter().all(|&x| x.is_finite())); + + println!("GeGLU output: {:?}", output_vals); + Ok(()) +} + +#[test] +fn test_reglu_activation() -> Result<()> { + let device = Device::Cpu; + let input = Tensor::new(&[[-1.0f32, 2.0, 3.0, 4.0]], &device)?; + let reglu = Activation::ReGlu; + let output = reglu.forward(&input)?; + + assert_eq!(output.dims(), &[1, 2]); + + // ReLU(-1) = 0, ReLU(2) = 2 + // output = [0 * 3, 2 * 4] = [0, 8] + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + assert_eq!(output_vals[0], 0.0); + assert_eq!(output_vals[1], 8.0); + + println!("ReGLU output: {:?}", output_vals); + Ok(()) +} + +#[test] +fn test_multidimensional_glu() -> Result<()> { + let device = Device::Cpu; + // Test with 3D tensor (batch_size=2, seq_len=3, hidden_dim=4) + let input = Tensor::randn(0f32, 1f32, (2, 3, 4), &device)?; + let glu = Activation::Glu; + let output = glu.forward(&input)?; + + // Should halve the last dimension: (2, 3, 4) -> (2, 3, 2) + assert_eq!(output.dims(), &[2, 3, 2]); + + println!( + "Multidimensional GLU: {:?} -> {:?}", + input.dims(), + output.dims() + ); + Ok(()) +} + +#[test] +fn test_phi3_compatibility() -> Result<()> { + let device = Device::Cpu; + + // Test that GLU variants work with typical transformer dimensions + let transformer_input = Tensor::randn(0f32, 1f32, (2, 128, 2048), &device)?; // (batch, seq, hidden*2) + + let glu = Activation::Glu; + let output = glu.forward(&transformer_input)?; + + // Should halve last dimension: (2, 128, 2048) -> (2, 128, 1024) + assert_eq!(output.dims(), &[2, 128, 1024]); + + println!( + "Phi-3 compatibility: {:?} -> {:?}", + transformer_input.dims(), + output.dims() + ); + Ok(()) +} + +#[test] +fn test_glu_variants_comparison() -> Result<()> { + let device = Device::Cpu; + let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?; + + let glu_output = Activation::Glu.forward(&input)?; + let geglu_output = Activation::GeGlu.forward(&input)?; + let reglu_output = Activation::ReGlu.forward(&input)?; + + // All should have same output shape + assert_eq!(glu_output.dims(), &[1, 2]); + assert_eq!(geglu_output.dims(), &[1, 2]); + assert_eq!(reglu_output.dims(), &[1, 2]); + + // Values should be different due to different gating functions + let glu_vals: Vec = glu_output.flatten_all()?.to_vec1()?; + let geglu_vals: Vec = geglu_output.flatten_all()?.to_vec1()?; + let reglu_vals: Vec = reglu_output.flatten_all()?.to_vec1()?; + + println!("GLU values: {:?}", glu_vals); + println!("GeGLU values: {:?}", geglu_vals); + println!("ReGLU values: {:?}", reglu_vals); + + // GLU and GeGLU should have different values (sigmoid vs GELU) + assert_ne!(glu_vals, geglu_vals); + // GLU and ReGLU should have different values (sigmoid vs ReLU) + assert_ne!(glu_vals, reglu_vals); + + Ok(()) +} + +#[test] +fn test_glu_variants_with_negative_inputs() -> Result<()> { + let device = Device::Cpu; + let input = Tensor::new(&[[-2.0f32, -1.0, 3.0, 4.0]], &device)?; + + // Test that all variants handle negative inputs correctly + let glu_output = Activation::Glu.forward(&input)?; + let geglu_output = Activation::GeGlu.forward(&input)?; + let reglu_output = Activation::ReGlu.forward(&input)?; + + assert_eq!(glu_output.dims(), &[1, 2]); + assert_eq!(geglu_output.dims(), &[1, 2]); + assert_eq!(reglu_output.dims(), &[1, 2]); + + let glu_vals: Vec = glu_output.flatten_all()?.to_vec1()?; + let geglu_vals: Vec = geglu_output.flatten_all()?.to_vec1()?; + let reglu_vals: Vec = reglu_output.flatten_all()?.to_vec1()?; + + println!("Negative input GLU: {:?}", glu_vals); + println!("Negative input GeGLU: {:?}", geglu_vals); + println!("Negative input ReGLU: {:?}", reglu_vals); + + // All should produce finite values + assert!(glu_vals.iter().all(|&x| x.is_finite())); + assert!(geglu_vals.iter().all(|&x| x.is_finite())); + assert!(reglu_vals.iter().all(|&x| x.is_finite())); + + Ok(()) +} + +#[test] +fn test_core_vs_enum_consistency() -> Result<()> { + let device = Device::Cpu; + let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?; + + // Test that core tensor methods match activation enum + let core_glu = input.glu()?; + let enum_glu = Activation::Glu.forward(&input)?; + + let core_geglu = input.geglu()?; + let enum_geglu = Activation::GeGlu.forward(&input)?; + + let core_reglu = input.reglu()?; + let enum_reglu = Activation::ReGlu.forward(&input)?; + + // Compare outputs (allowing for small floating point differences) + let core_glu_vals: Vec = core_glu.flatten_all()?.to_vec1()?; + let enum_glu_vals: Vec = enum_glu.flatten_all()?.to_vec1()?; + + let core_geglu_vals: Vec = core_geglu.flatten_all()?.to_vec1()?; + let enum_geglu_vals: Vec = enum_geglu.flatten_all()?.to_vec1()?; + + let core_reglu_vals: Vec = core_reglu.flatten_all()?.to_vec1()?; + let enum_reglu_vals: Vec = enum_reglu.flatten_all()?.to_vec1()?; + + // GLU consistency + for (core_val, enum_val) in core_glu_vals.iter().zip(enum_glu_vals.iter()) { + assert!( + (core_val - enum_val).abs() < 1e-6, + "GLU core vs enum mismatch: {} vs {}", + core_val, + enum_val + ); + } + + // GeGLU consistency + for (core_val, enum_val) in core_geglu_vals.iter().zip(enum_geglu_vals.iter()) { + assert!( + (core_val - enum_val).abs() < 1e-6, + "GeGLU core vs enum mismatch: {} vs {}", + core_val, + enum_val + ); + } + + // ReGLU consistency + for (core_val, enum_val) in core_reglu_vals.iter().zip(enum_reglu_vals.iter()) { + assert!( + (core_val - enum_val).abs() < 1e-6, + "ReGLU core vs enum mismatch: {} vs {}", + core_val, + enum_val + ); + } + + println!("Core vs Enum consistency test passed for all GLU variants"); + Ok(()) +} + +#[test] +fn test_glu_performance_characteristics() -> Result<()> { + let device = Device::Cpu; + + // Test different sizes to verify linear scaling + let sizes = vec![8, 16, 32, 64]; + + for size in sizes { + let input = Tensor::randn(0f32, 1f32, (1, size), &device)?; + + let glu_output = Activation::Glu.forward(&input)?; + let geglu_output = Activation::GeGlu.forward(&input)?; + let reglu_output = Activation::ReGlu.forward(&input)?; + + // All should halve the input size + assert_eq!(glu_output.dims(), &[1, size / 2]); + assert_eq!(geglu_output.dims(), &[1, size / 2]); + assert_eq!(reglu_output.dims(), &[1, size / 2]); + + println!( + "Size {}: All GLU variants produce correct output dimensions", + size + ); + } + + Ok(()) +} + +#[test] +fn test_glu_gradient_flow() -> Result<()> { + let device = Device::Cpu; + + // Test that GLU variants allow proper gradient flow + let input = Tensor::randn(0f32, 1f32, (2, 8), &device)?; + + let activations = vec![ + ("GLU", Activation::Glu), + ("GeGLU", Activation::GeGlu), + ("ReGLU", Activation::ReGlu), + ]; + + for (name, activation) in activations { + let output = activation.forward(&input)?; + + // Verify output is differentiable (not zero everywhere) + let output_vals: Vec = output.flatten_all()?.to_vec1()?; + let non_zero_count = output_vals.iter().filter(|&&x| x.abs() > 1e-6).count(); + + assert!( + non_zero_count > 0, + "Activation {} produced all-zero output", + name + ); + + println!( + "{} gradient flow test passed ({} non-zero values)", + name, non_zero_count + ); + } + + Ok(()) +} From 8c0429a8d422442dda4145947fe32434ddc7334b Mon Sep 17 00:00:00 2001 From: Artem Ryzhov Date: Fri, 20 Jun 2025 13:35:13 +0200 Subject: [PATCH 2/2] Add GLU integration to CPU benchmarks and Phi-3 model --- candle-nn/examples/cpu_benchmarks.rs | 56 +++++++++++++++++++++++++- candle-transformers/src/models/phi3.rs | 33 ++++++++++++++- 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs index 430316b89d..ac4dcb76b4 100644 --- a/candle-nn/examples/cpu_benchmarks.rs +++ b/candle-nn/examples/cpu_benchmarks.rs @@ -7,8 +7,8 @@ extern crate accelerate_src; use candle::quantized::GgmlType; use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D}; +use candle_nn::Activation; use clap::{Parser, Subcommand}; - const CHECK_CONV2D: bool = false; trait Benchmark { @@ -21,6 +21,54 @@ trait Benchmark { const ITERS: usize; } +struct GluActivation; +impl Benchmark for GluActivation { + type PreProcessData = Tensor; + type RunResult = Tensor; + + fn preprocess() -> Result { + Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu) + } + + fn run_one(data: &Self::PreProcessData) -> Result { + Activation::Glu.forward(data) + } + + const ITERS: usize = 100; +} + +struct GeGluActivation; +impl Benchmark for GeGluActivation { + type PreProcessData = Tensor; + type RunResult = Tensor; + + fn preprocess() -> Result { + Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu) + } + + fn run_one(data: &Self::PreProcessData) -> Result { + Activation::GeGlu.forward(data) + } + + const ITERS: usize = 100; +} + +struct ReGluActivation; +impl Benchmark for ReGluActivation { + type PreProcessData = Tensor; + type RunResult = Tensor; + + fn preprocess() -> Result { + Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu) + } + + fn run_one(data: &Self::PreProcessData) -> Result { + Activation::ReGlu.forward(data) + } + + const ITERS: usize = 100; +} + struct Im2Col { h_k: usize, w_k: usize, @@ -313,6 +361,9 @@ enum Task { Softmax, SoftmaxLastDim, Cat, + GluActivation, + GeGluActivation, + ReGluActivation, } #[derive(Parser, Debug)] @@ -338,6 +389,9 @@ fn main() -> Result<()> { Task::SoftmaxLastDim => run::(args.iters)?, Task::Qmatmul => run::(args.iters)?, Task::Cat => run::(args.iters)?, + Task::GluActivation => run::(args.iters)?, + Task::GeGluActivation => run::(args.iters)?, + Task::ReGluActivation => run::(args.iters)?, } Ok(()) } diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs index 6535d9a4fd..9f4fe28b68 100644 --- a/candle-transformers/src/models/phi3.rs +++ b/candle-transformers/src/models/phi3.rs @@ -21,7 +21,7 @@ // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm}; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; -use candle_nn::VarBuilder; +use candle_nn::{Activation, VarBuilder}; use std::sync::Arc; #[derive(Debug, Clone, serde::Deserialize)] @@ -59,12 +59,41 @@ pub struct Config { #[serde(default)] pub tie_word_embeddings: bool, } - impl Config { + pub fn mini_4k_instruct() -> Self { + Self { + vocab_size: 32064, + hidden_act: Activation::GeGlu, + hidden_size: 3072, + intermediate_size: 8192, + num_hidden_layers: 32, + num_attention_heads: 32, + num_key_value_heads: 32, + rms_norm_eps: 1e-5, + rope_theta: 10000.0, + bos_token_id: Some(1), + eos_token_id: Some(2), + rope_scaling: None, + max_position_embeddings: 4096, + original_max_position_embeddings: None, + partial_rotary_factor: None, + tie_word_embeddings: false, + } + } + + pub fn with_activation(mut self, activation: Activation) -> Self { + self.hidden_act = activation; + self + } pub fn head_dim(&self) -> usize { self.hidden_size / self.num_attention_heads } } +// impl Config { +// pub fn head_dim(&self) -> usize { +// self.hidden_size / self.num_attention_heads +// } +// } #[derive(Debug, Clone)] pub struct RotaryEmbedding {