52 changes: 52 additions & 0 deletions candle-core/src/tensor.rs
@@ -2677,6 +2677,58 @@ impl Tensor {
}
Ok(result)
}
/// Gated Linear Unit (GLU) activation function.
///
/// GLU(x) = σ(x_left) ⊙ x_right where x is split in half along the last dimension.
/// The input tensor's last dimension must be even.
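///
/// A minimal illustrative sketch (values mirror the unit tests added in this PR):
///
/// ```ignore
/// let x = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &Device::Cpu)?; // shape [1, 4]
/// let y = x.glu()?; // sigmoid([1., 2.]) * [3., 4.], shape [1, 2]
/// assert_eq!(y.dims(), &[1, 2]);
/// ```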
pub fn glu(&self) -> Result<Self> {
let dim = self.dim(crate::D::Minus1)?;
if dim % 2 != 0 {
crate::bail!("GLU requires input dimension to be even, got {}", dim);
}

let half_dim = dim / 2;
let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?;
let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?;
let gate = (x_left.neg()?.exp()? + 1.0)?.recip()?; // sigmoid(x) = 1/(1+exp(-x))
&gate * &x_right
}

/// GeGLU (GELU-Gated Linear Unit) activation function.
///
/// GeGLU(x) = GELU(x_left) ⊙ x_right where x is split in half along the last dimension.
/// The input tensor's last dimension must be even.
pub fn geglu(&self) -> Result<Self> {
let dim = self.dim(crate::D::Minus1)?;
if dim % 2 != 0 {
crate::bail!("GeGLU requires input dimension to be even, got {}", dim);
}

let half_dim = dim / 2;
let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?;
let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?;

let gate = x_left.gelu()?;
&gate * &x_right
}

/// ReGLU (ReLU-Gated Linear Unit) activation function.
///
/// ReGLU(x) = ReLU(x_left) ⊙ x_right where x is split in half along the last dimension.
/// The input tensor's last dimension must be even.
pub fn reglu(&self) -> Result<Self> {
let dim = self.dim(crate::D::Minus1)?;
if dim % 2 != 0 {
crate::bail!("ReGLU requires input dimension to be even, got {}", dim);
}

let half_dim = dim / 2;
let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?;
let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?;

let gate = x_left.relu()?; // Use existing ReLU method
&gate * &x_right
}
}

macro_rules! bin_trait {
75 changes: 75 additions & 0 deletions candle-core/tests/tensor_tests.rs
@@ -1,5 +1,56 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
fn glu_activation(device: &Device) -> Result<()> {
let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], device)?;
let output = input.glu()?;

// Verify output shape (most important)
assert_eq!(output.dims(), &[1, 2]);

// Verify output is finite and reasonable
let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
assert!(output_vals.iter().all(|&x| x.is_finite()));
assert!(output_vals.iter().all(|&x| x >= 0.0)); // GLU with positive inputs should be positive

Ok(())
}

fn geglu_activation(device: &Device) -> Result<()> {
let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], device)?;
let output = input.geglu()?;

assert_eq!(output.dims(), &[1, 2]);

let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
assert!(output_vals.iter().all(|&x| x.is_finite()));

Ok(())
}

fn reglu_activation(device: &Device) -> Result<()> {
let input = Tensor::new(&[[-1.0f32, 2.0, 3.0, 4.0]], device)?;
let output = input.reglu()?;

assert_eq!(output.dims(), &[1, 2]);

// ReLU(-1) = 0, ReLU(2) = 2
// output = [0 * 3, 2 * 4] = [0, 8]
let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
assert_eq!(output_vals[0], 0.0);
assert_eq!(output_vals[1], 8.0);

Ok(())
}

fn glu_odd_dimension_error(device: &Device) -> Result<()> {
let input = Tensor::new(&[[1.0f32, 2.0, 3.0]], device)?;

// Should error with odd dimension
assert!(input.glu().is_err());
assert!(input.geglu().is_err());
assert!(input.reglu().is_err());

Ok(())
}
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
let (dim1, dim2) = tensor.dims2()?;
@@ -1656,6 +1707,30 @@ test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
test_device!(
glu_activation,
glu_activation_cpu,
glu_activation_gpu,
glu_activation_metal
);
test_device!(
geglu_activation,
geglu_activation_cpu,
geglu_activation_gpu,
geglu_activation_metal
);
test_device!(
reglu_activation,
reglu_activation_cpu,
reglu_activation_gpu,
reglu_activation_metal
);
test_device!(
glu_odd_dimension_error,
glu_odd_dimension_error_cpu,
glu_odd_dimension_error_gpu,
glu_odd_dimension_error_metal
);

// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381
3 changes: 2 additions & 1 deletion candle-nn/benches/bench_main.rs
@@ -4,5 +4,6 @@ use criterion::criterion_main;
criterion_main!(
benchmarks::softmax::benches,
benchmarks::layer_norm::benches,
benchmarks::conv::benches
benchmarks::conv::benches,
benchmarks::activation_bench::benches,
);
189 changes: 189 additions & 0 deletions candle-nn/benches/benchmarks/activation_bench.rs
@@ -0,0 +1,189 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle::{DType, Device, Module, Tensor};
use candle_nn::Activation;
use criterion::{black_box, criterion_group, Criterion};
use std::time::Instant;

fn run_activation_benchmark(
c: &mut Criterion,
device: &Device,
dtype: DType,
activation: Activation,
name: &str,
) {
let sizes = [512, 1024, 2048, 4096, 8192];

for &size in &sizes {
let input = Tensor::randn(0f32, 1f32, (1, size), device)
.unwrap()
.to_dtype(dtype)
.unwrap();

let bench_name = format!("{}_{}_{}", device.bench_name(name), dtype.as_str(), size);

c.bench_function(&bench_name, |b| {
b.iter_custom(|iters| {
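// Sync before starting the timer and again after the loop so any queued device work is fully included in the measured time.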
device.sync().unwrap();
let start = Instant::now();
for _i in 0..iters {
let _result = black_box(activation.forward(black_box(&input)).unwrap());
}
device.sync().unwrap();
start.elapsed()
})
});
}
}

fn run_core_tensor_benchmark<F>(
c: &mut Criterion,
device: &Device,
dtype: DType,
name: &str,
activation_fn: F,
) where
F: Fn(&Tensor) -> candle::Result<Tensor> + Copy,
{
let sizes = [512, 1024, 2048, 4096, 8192];

for &size in &sizes {
// For GLU variants, we need even dimensions (they split the input)
let input_size = if name.contains("glu") || name.contains("GLU") {
size * 2 // Double the size so after GLU we get 'size' output
} else {
size
};

let input = Tensor::randn(0f32, 1f32, (1, input_size), device)
.unwrap()
.to_dtype(dtype)
.unwrap();

let bench_name = format!(
"{}_core_{}_{}",
device.bench_name(name),
dtype.as_str(),
size
);

c.bench_function(&bench_name, |b| {
b.iter_custom(|iters| {
device.sync().unwrap();
let start = Instant::now();
for _i in 0..iters {
let _result = black_box(activation_fn(black_box(&input)).unwrap());
}
device.sync().unwrap();
start.elapsed()
})
});
}
}

fn run_comparison_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let sizes = [1024, 2048, 4096];

for &size in &sizes {
let input = Tensor::randn(0f32, 1f32, (1, size * 2), device)
.unwrap()
.to_dtype(dtype)
.unwrap();

let bench_name = format!(
"{}_comparison_{}_{}",
device.bench_name(name),
dtype.as_str(),
size
);

// Create a benchmark group for direct comparison
let mut group = c.benchmark_group(&bench_name);

// Benchmark via Activation enum
group.bench_function("enum", |b| {
let activation = match name {
"glu" => Activation::Glu,
"geglu" => Activation::GeGlu,
"reglu" => Activation::ReGlu,
_ => Activation::Glu,
};
b.iter_custom(|iters| {
device.sync().unwrap();
let start = Instant::now();
for _i in 0..iters {
let _result = black_box(activation.forward(black_box(&input)).unwrap());
}
device.sync().unwrap();
start.elapsed()
})
});

// Benchmark via core tensor method
group.bench_function("core", |b| {
b.iter_custom(|iters| {
device.sync().unwrap();
let start = Instant::now();
for _i in 0..iters {
let _result = match name {
"glu" => black_box(black_box(&input).glu().unwrap()),
"geglu" => black_box(black_box(&input).geglu().unwrap()),
"reglu" => black_box(black_box(&input).reglu().unwrap()),
_ => black_box(black_box(&input).glu().unwrap()),
};
}
device.sync().unwrap();
start.elapsed()
})
});

group.finish();
}
}

fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();

for device in handler.devices {
// Benchmark GLU variants via Activation enum
run_activation_benchmark(c, &device, DType::F32, Activation::Glu, "glu_enum_f32");
run_activation_benchmark(c, &device, DType::F32, Activation::GeGlu, "geglu_enum_f32");
run_activation_benchmark(c, &device, DType::F32, Activation::ReGlu, "reglu_enum_f32");

// Benchmark GLU variants via core tensor methods
run_core_tensor_benchmark(c, &device, DType::F32, "glu_core", |t| t.glu());
run_core_tensor_benchmark(c, &device, DType::F32, "geglu_core", |t| t.geglu());
run_core_tensor_benchmark(c, &device, DType::F32, "reglu_core", |t| t.reglu());

// Direct comparison benchmarks
run_comparison_benchmark(c, &device, DType::F32, "glu");
run_comparison_benchmark(c, &device, DType::F32, "geglu");
run_comparison_benchmark(c, &device, DType::F32, "reglu");

// Compare with existing activations (for context)
run_activation_benchmark(c, &device, DType::F32, Activation::Silu, "silu_f32");
run_activation_benchmark(c, &device, DType::F32, Activation::Swiglu, "swiglu_f32");
run_activation_benchmark(c, &device, DType::F32, Activation::Gelu, "gelu_f32");

// Core tensor equivalents for comparison
run_core_tensor_benchmark(c, &device, DType::F32, "silu_core", |t| t.silu());
run_core_tensor_benchmark(c, &device, DType::F32, "gelu_core", |t| t.gelu());
run_core_tensor_benchmark(c, &device, DType::F32, "relu_core", |t| t.relu());

// Test different data types for GLU variants
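// The Metal backend does not support f64, so skip double-precision runs there.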
if !device.is_metal() {
run_core_tensor_benchmark(c, &device, DType::F64, "glu_core", |t| t.glu());
run_core_tensor_benchmark(c, &device, DType::F64, "geglu_core", |t| t.geglu());
run_core_tensor_benchmark(c, &device, DType::F64, "reglu_core", |t| t.reglu());
}

run_core_tensor_benchmark(c, &device, DType::F16, "glu_core", |t| t.glu());
run_core_tensor_benchmark(c, &device, DType::F16, "geglu_core", |t| t.geglu());
run_core_tensor_benchmark(c, &device, DType::F16, "reglu_core", |t| t.reglu());

run_core_tensor_benchmark(c, &device, DType::BF16, "glu_core", |t| t.glu());
run_core_tensor_benchmark(c, &device, DType::BF16, "geglu_core", |t| t.geglu());
run_core_tensor_benchmark(c, &device, DType::BF16, "reglu_core", |t| t.reglu());
}
}

criterion_group!(benches, criterion_benchmark);
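
A typical way to run these benchmarks locally (a sketch; the trailing "glu" is just a Criterion name filter over the bench names defined above):

cargo bench -p candle-nn
cargo bench -p candle-nn -- glu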
1 change: 1 addition & 0 deletions candle-nn/benches/benchmarks/mod.rs
@@ -1,3 +1,4 @@
pub(crate) mod activation_bench;
pub(crate) mod conv;
pub(crate) mod layer_norm;
pub(crate) mod softmax;