diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index d71630212d..b66a32e24a 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -2754,6 +2754,58 @@ impl Tensor {
         }
         Ok(result)
     }
+    /// Gated Linear Unit (GLU) activation function.
+    ///
+    /// GLU(x) = σ(x_left) ⊙ x_right where x is split in half along the last dimension.
+    /// The input tensor's last dimension must be even.
+    pub fn glu(&self) -> Result<Self> {
+        let dim = self.dim(crate::D::Minus1)?;
+        if dim % 2 != 0 {
+            crate::bail!("GLU requires input dimension to be even, got {}", dim);
+        }
+
+        let half_dim = dim / 2;
+        let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?;
+        let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?;
+        let gate = (x_left.neg()?.exp()? + 1.0)?.recip()?; // sigmoid(x) = 1/(1+exp(-x))
+        &gate * &x_right
+    }
+
+    /// GeGLU (GELU-Gated Linear Unit) activation function.
+    ///
+    /// GeGLU(x) = GELU(x_left) ⊙ x_right where x is split in half along the last dimension.
+    /// The input tensor's last dimension must be even.
+    pub fn geglu(&self) -> Result<Self> {
+        let dim = self.dim(crate::D::Minus1)?;
+        if dim % 2 != 0 {
+            crate::bail!("GeGLU requires input dimension to be even, got {}", dim);
+        }
+
+        let half_dim = dim / 2;
+        let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?;
+        let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?;
+
+        let gate = x_left.gelu()?;
+        &gate * &x_right
+    }
+
+    /// ReGLU (ReLU-Gated Linear Unit) activation function.
+    ///
+    /// ReGLU(x) = ReLU(x_left) ⊙ x_right where x is split in half along the last dimension.
+    /// The input tensor's last dimension must be even.
+    pub fn reglu(&self) -> Result<Self> {
+        let dim = self.dim(crate::D::Minus1)?;
+        if dim % 2 != 0 {
+            crate::bail!("ReGLU requires input dimension to be even, got {}", dim);
+        }
+
+        let half_dim = dim / 2;
+        let x_left = self.narrow(crate::D::Minus1, 0, half_dim)?;
+        let x_right = self.narrow(crate::D::Minus1, half_dim, half_dim)?;
+
+        let gate = x_left.relu()?; // Use existing ReLU method
+        &gate * &x_right
+    }
 }
 
 macro_rules! bin_trait {
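A quick sanity check of the three new tensor methods (a minimal sketch; the printed values are hand-computed expectations, not output captured from this patch):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // The left half [1, 2] gates the right half [3, 4]:
    // glu -> [sigmoid(1) * 3, sigmoid(2) * 4] ≈ [2.1932, 3.5232]
    let x = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &Device::Cpu)?;
    println!("{}", x.glu()?);   // shape (1, 2), sigmoid-gated
    println!("{}", x.geglu()?); // same split, GELU gate
    println!("{}", x.reglu()?); // same split, ReLU gate
    Ok(())
}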
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index d264cc0bd9..81a9c83409 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,6 +1,57 @@
 use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
 use float8::F8E4M3;
 
+fn glu_activation(device: &Device) -> Result<()> {
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], device)?;
+    let output = input.glu()?;
+
+    // Verify output shape (most important)
+    assert_eq!(output.dims(), &[1, 2]);
+
+    // Verify output is finite and reasonable
+    let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+    assert!(output_vals.iter().all(|&x| x.is_finite()));
+    assert!(output_vals.iter().all(|&x| x >= 0.0)); // GLU with positive inputs should be positive
+
+    Ok(())
+}
+
+fn geglu_activation(device: &Device) -> Result<()> {
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], device)?;
+    let output = input.geglu()?;
+
+    assert_eq!(output.dims(), &[1, 2]);
+
+    let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+    assert!(output_vals.iter().all(|&x| x.is_finite()));
+
+    Ok(())
+}
+
+fn reglu_activation(device: &Device) -> Result<()> {
+    let input = Tensor::new(&[[-1.0f32, 2.0, 3.0, 4.0]], device)?;
+    let output = input.reglu()?;
+
+    assert_eq!(output.dims(), &[1, 2]);
+
+    // ReLU(-1) = 0, ReLU(2) = 2
+    // output = [0 * 3, 2 * 4] = [0, 8]
+    let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+    assert_eq!(output_vals[0], 0.0);
+    assert_eq!(output_vals[1], 8.0);
+
+    Ok(())
+}
+
+fn glu_odd_dimension_error(device: &Device) -> Result<()> {
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0]], device)?;
+
+    // Should error with odd dimension
+    assert!(input.glu().is_err());
+    assert!(input.geglu().is_err());
+    assert!(input.reglu().is_err());
+
+    Ok(())
+}
+
 fn zeros(device: &Device) -> Result<()> {
     let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
     let (dim1, dim2) = tensor.dims2()?;
@@ -1693,6 +1744,30 @@ test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
 test_device!(asort, asort_cpu, asort_gpu, asort_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
 test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
+test_device!(
+    glu_activation,
+    glu_activation_cpu,
+    glu_activation_gpu,
+    glu_activation_metal
+);
+test_device!(
+    geglu_activation,
+    geglu_activation_cpu,
+    geglu_activation_gpu,
+    geglu_activation_metal
+);
+test_device!(
+    reglu_activation,
+    reglu_activation_cpu,
+    reglu_activation_gpu,
+    reglu_activation_metal
+);
+test_device!(
+    glu_odd_dimension_error,
+    glu_odd_dimension_error_cpu,
+    glu_odd_dimension_error_gpu,
+    glu_odd_dimension_error_metal
+);
 
 fn tensor_send_sync(device: &Device) -> Result<()> {
     let tensor = Tensor::new(vec![1.0f32, 2.0, 3.0], device)?;
diff --git a/candle-nn/benches/bench_main.rs b/candle-nn/benches/bench_main.rs
index 64d9b8b46e..a16d87cc17 100644
--- a/candle-nn/benches/bench_main.rs
+++ b/candle-nn/benches/bench_main.rs
@@ -4,5 +4,6 @@ use criterion::criterion_main;
 criterion_main!(
     benchmarks::softmax::benches,
     benchmarks::layer_norm::benches,
-    benchmarks::conv::benches
+    benchmarks::conv::benches,
+    benchmarks::activation_bench::benches,
 );
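For reference, each `test_device!` invocation in the tensor tests above fans the shared test function out to one `#[test]` per backend. Roughly, as a sketch of the macro's effect rather than its literal expansion (the feature gates are assumptions):

#[test]
fn glu_activation_cpu() -> Result<()> {
    glu_activation(&Device::Cpu)
}

// ...with matching glu_activation_gpu / glu_activation_metal tests emitted
// behind the crate's cuda / metal feature gates.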
diff --git a/candle-nn/benches/benchmarks/activation_bench.rs b/candle-nn/benches/benchmarks/activation_bench.rs
new file mode 100644
index 0000000000..33aebfab1e
--- /dev/null
+++ b/candle-nn/benches/benchmarks/activation_bench.rs
@@ -0,0 +1,189 @@
+use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
+use candle::{DType, Device, Module, Tensor};
+use candle_nn::Activation;
+use criterion::{black_box, criterion_group, Criterion};
+use std::time::Instant;
+
+fn run_activation_benchmark(
+    c: &mut Criterion,
+    device: &Device,
+    dtype: DType,
+    activation: Activation,
+    name: &str,
+) {
+    let sizes = [512, 1024, 2048, 4096, 8192];
+
+    for &size in &sizes {
+        let input = Tensor::randn(0f32, 1f32, (1, size), device)
+            .unwrap()
+            .to_dtype(dtype)
+            .unwrap();
+
+        let bench_name = format!("{}_{}_{}", device.bench_name(name), dtype.as_str(), size);
+
+        c.bench_function(&bench_name, |b| {
+            b.iter_custom(|iters| {
+                device.sync().unwrap();
+                let start = Instant::now();
+                for _i in 0..iters {
+                    let _result = black_box(activation.forward(black_box(&input)).unwrap());
+                }
+                device.sync().unwrap();
+                start.elapsed()
+            })
+        });
+    }
+}
+
+fn run_core_tensor_benchmark<F>(
+    c: &mut Criterion,
+    device: &Device,
+    dtype: DType,
+    name: &str,
+    activation_fn: F,
+) where
+    F: Fn(&Tensor) -> candle::Result<Tensor> + Copy,
+{
+    let sizes = [512, 1024, 2048, 4096, 8192];
+
+    for &size in &sizes {
+        // For GLU variants, we need even dimensions (they split the input)
+        let input_size = if name.contains("glu") || name.contains("GLU") {
+            size * 2 // Double the size so after GLU we get 'size' output
+        } else {
+            size
+        };
+
+        let input = Tensor::randn(0f32, 1f32, (1, input_size), device)
+            .unwrap()
+            .to_dtype(dtype)
+            .unwrap();
+
+        let bench_name = format!(
+            "{}_core_{}_{}",
+            device.bench_name(name),
+            dtype.as_str(),
+            size
+        );
+
+        c.bench_function(&bench_name, |b| {
+            b.iter_custom(|iters| {
+                device.sync().unwrap();
+                let start = Instant::now();
+                for _i in 0..iters {
+                    let _result = black_box(activation_fn(black_box(&input)).unwrap());
+                }
+                device.sync().unwrap();
+                start.elapsed()
+            })
+        });
+    }
+}
+
+fn run_comparison_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
+    let sizes = [1024, 2048, 4096];
+
+    for &size in &sizes {
+        let input = Tensor::randn(0f32, 1f32, (1, size * 2), device)
+            .unwrap()
+            .to_dtype(dtype)
+            .unwrap();
+
+        let bench_name = format!(
+            "{}_comparison_{}_{}",
+            device.bench_name(name),
+            dtype.as_str(),
+            size
+        );
+
+        // Create a benchmark group for direct comparison
+        let mut group = c.benchmark_group(&bench_name);
+
+        // Benchmark via Activation enum
+        group.bench_function("enum", |b| {
+            let activation = match name {
+                "glu" => Activation::Glu,
+                "geglu" => Activation::GeGlu,
+                "reglu" => Activation::ReGlu,
+                _ => Activation::Glu,
+            };
+            b.iter_custom(|iters| {
+                device.sync().unwrap();
+                let start = Instant::now();
+                for _i in 0..iters {
+                    let _result = black_box(activation.forward(black_box(&input)).unwrap());
+                }
+                device.sync().unwrap();
+                start.elapsed()
+            })
+        });
+
+        // Benchmark via core tensor method
+        group.bench_function("core", |b| {
+            b.iter_custom(|iters| {
+                device.sync().unwrap();
+                let start = Instant::now();
+                for _i in 0..iters {
+                    let _result = match name {
+                        "glu" => black_box(black_box(&input).glu().unwrap()),
+                        "geglu" => black_box(black_box(&input).geglu().unwrap()),
+                        "reglu" => black_box(black_box(&input).reglu().unwrap()),
+                        _ => black_box(black_box(&input).glu().unwrap()),
+                    };
+                }
+                device.sync().unwrap();
+                start.elapsed()
+            })
+        });
+
+        group.finish();
+    }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let handler = BenchDeviceHandler::new().unwrap();
+
+    for device in handler.devices {
+        // Benchmark GLU variants via Activation enum
+        run_activation_benchmark(c, &device, DType::F32, Activation::Glu, "glu_enum_f32");
+        run_activation_benchmark(c, &device, DType::F32, Activation::GeGlu, "geglu_enum_f32");
+        run_activation_benchmark(c, &device, DType::F32, Activation::ReGlu, "reglu_enum_f32");
+
+        // Benchmark GLU variants via core tensor methods
+        run_core_tensor_benchmark(c, &device, DType::F32, "glu_core", |t| t.glu());
+        run_core_tensor_benchmark(c, &device, DType::F32, "geglu_core", |t| t.geglu());
+        run_core_tensor_benchmark(c, &device, DType::F32, "reglu_core", |t| t.reglu());
+
+        // Direct comparison benchmarks
+        run_comparison_benchmark(c, &device, DType::F32, "glu");
+        run_comparison_benchmark(c, &device, DType::F32, "geglu");
+        run_comparison_benchmark(c, &device, DType::F32, "reglu");
+
+        // Compare with existing activations (for context)
+        run_activation_benchmark(c, &device, DType::F32, Activation::Silu, "silu_f32");
+        run_activation_benchmark(c, &device, DType::F32, Activation::Swiglu, "swiglu_f32");
+        run_activation_benchmark(c, &device, DType::F32, Activation::Gelu, "gelu_f32");
+
+        // Core tensor equivalents for comparison
+        run_core_tensor_benchmark(c, &device, DType::F32, "silu_core", |t| t.silu());
+        run_core_tensor_benchmark(c, &device, DType::F32, "gelu_core", |t| t.gelu());
+        run_core_tensor_benchmark(c, &device, DType::F32, "relu_core", |t| t.relu());
+
+        // Test different data types for GLU variants
+        if !device.is_metal() {
+            run_core_tensor_benchmark(c, &device, DType::F64, "glu_core", |t| t.glu());
+            run_core_tensor_benchmark(c, &device, DType::F64, "geglu_core", |t| t.geglu());
+            run_core_tensor_benchmark(c, &device, DType::F64, "reglu_core", |t| t.reglu());
+        }
+
+        run_core_tensor_benchmark(c, &device, DType::F16, "glu_core", |t| t.glu());
+        run_core_tensor_benchmark(c, &device, DType::F16, "geglu_core", |t| t.geglu());
+        run_core_tensor_benchmark(c, &device, DType::F16, "reglu_core", |t| t.reglu());
+
+        run_core_tensor_benchmark(c, &device, DType::BF16, "glu_core", |t| t.glu());
+        run_core_tensor_benchmark(c, &device, DType::BF16, "geglu_core", |t| t.geglu());
+        run_core_tensor_benchmark(c, &device, DType::BF16, "reglu_core", |t| t.reglu());
+    }
+}
+
+criterion_group!(benches, criterion_benchmark);
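These benches time with `iter_custom` plus explicit device syncs so GPU kernels are measured end-to-end rather than as asynchronous launches. They hook into the existing criterion harness, so the standard invocation runs them (assuming the bench target keeps its `bench_main` name; the trailing `glu` is just a criterion name filter):

cargo bench --bench bench_main -- glu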
diff --git a/candle-nn/benches/benchmarks/mod.rs b/candle-nn/benches/benchmarks/mod.rs
index c1ebfa0f50..2d6cd2f9d4 100644
--- a/candle-nn/benches/benchmarks/mod.rs
+++ b/candle-nn/benches/benchmarks/mod.rs
@@ -1,3 +1,4 @@
+pub(crate) mod activation_bench;
 pub(crate) mod conv;
 pub(crate) mod layer_norm;
 pub(crate) mod softmax;
diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index 430316b89d..ac4dcb76b4 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -7,8 +7,8 @@ extern crate accelerate_src;
 
 use candle::quantized::GgmlType;
 use candle::{CpuStorage, Device, Layout, Module, Result, Shape, Tensor, D};
+use candle_nn::Activation;
 use clap::{Parser, Subcommand};
-
 const CHECK_CONV2D: bool = false;
 
 trait Benchmark {
@@ -21,6 +21,54 @@ trait Benchmark {
     const ITERS: usize;
 }
 
+struct GluActivation;
+impl Benchmark for GluActivation {
+    type PreProcessData = Tensor;
+    type RunResult = Tensor;
+
+    fn preprocess() -> Result<Self::PreProcessData> {
+        Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu)
+    }
+
+    fn run_one(data: &Self::PreProcessData) -> Result<Self::RunResult> {
+        Activation::Glu.forward(data)
+    }
+
+    const ITERS: usize = 100;
+}
+
+struct GeGluActivation;
+impl Benchmark for GeGluActivation {
+    type PreProcessData = Tensor;
+    type RunResult = Tensor;
+
+    fn preprocess() -> Result<Self::PreProcessData> {
+        Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu)
+    }
+
+    fn run_one(data: &Self::PreProcessData) -> Result<Self::RunResult> {
+        Activation::GeGlu.forward(data)
+    }
+
+    const ITERS: usize = 100;
+}
+
+struct ReGluActivation;
+impl Benchmark for ReGluActivation {
+    type PreProcessData = Tensor;
+    type RunResult = Tensor;
+
+    fn preprocess() -> Result<Self::PreProcessData> {
+        Tensor::randn(0f32, 1., (1024, 2048), &Device::Cpu)
+    }
+
+    fn run_one(data: &Self::PreProcessData) -> Result<Self::RunResult> {
+        Activation::ReGlu.forward(data)
+    }
+
+    const ITERS: usize = 100;
+}
+
 struct Im2Col {
     h_k: usize,
     w_k: usize,
@@ -313,6 +361,9 @@ enum Task {
     Softmax,
     SoftmaxLastDim,
     Cat,
+    GluActivation,
+    GeGluActivation,
+    ReGluActivation,
 }
 
 #[derive(Parser, Debug)]
@@ -338,6 +389,9 @@ fn main() -> Result<()> {
         Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
         Task::Qmatmul => run::<QMatMul>(args.iters)?,
         Task::Cat => run::<Cat>(args.iters)?,
+        Task::GluActivation => run::<GluActivation>(args.iters)?,
+        Task::GeGluActivation => run::<GeGluActivation>(args.iters)?,
+        Task::ReGluActivation => run::<ReGluActivation>(args.iters)?,
     }
     Ok(())
 }
diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index f2a992afcc..9d2ca70239 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -24,13 +24,24 @@ pub enum Activation {
     LeakyRelu(f64),
     #[serde(alias = "gelu_pytorch_tanh")]
     GeluPytorchTanh,
+
+    // New GLU variants
+    /// Gated Linear Unit - splits input in half, applies sigmoid to one half,
+    /// multiplies with the other half. Commonly used in transformer FFNs.
+    #[serde(alias = "glu")]
+    Glu,
+    /// GeGLU - GLU variant using GELU instead of sigmoid
+    #[serde(alias = "geglu")]
+    GeGlu,
+    /// ReGLU - GLU variant using ReLU instead of sigmoid
+    #[serde(alias = "reglu")]
+    ReGlu,
 }
 
 impl super::Module for Activation {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         match self {
             Self::Gelu => xs.gelu_erf(),
-            // https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
             Self::NewGelu => xs.gelu(),
             Self::Relu => xs.relu(),
             Self::Relu2 => xs.relu()?.sqr(),
@@ -45,6 +56,9 @@ impl super::Module for Activation {
             &Self::Elu(alpha) => xs.elu(alpha),
             &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
             Self::GeluPytorchTanh => xs.gelu(),
+            Self::Glu => xs.glu(),
+            Self::GeGlu => xs.geglu(),
+            Self::ReGlu => xs.reglu(),
         }
     }
 }
diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs
index cce6050806..c664c270d1 100644
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -279,6 +279,22 @@ impl SimpleBackend for VarMap {
         self.data().lock().unwrap().contains_key(name)
     }
 }
+impl SimpleBackend for crate::var_map::ConcurrentVarMap {
+    fn get(
+        &self,
+        s: Shape,
+        name: &str,
+        h: crate::Init,
+        dtype: DType,
+        dev: &Device,
+    ) -> Result<Tensor> {
+        self.get(s, name, h, dtype, dev)
+    }
+
+    fn contains_tensor(&self, name: &str) -> bool {
+        self.contains_key(name)
+    }
+}
 
 #[allow(dead_code)]
 pub struct SafeTensorWithRouting<'a> {
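Because `Activation` derives `Deserialize` (which is what the `serde(alias = ...)` attributes above hook into), HF-style JSON configs can select the new variants by their lowercase names. A minimal sketch, assuming `serde_json` is available as a dev-dependency:

use candle_nn::Activation;

fn main() -> serde_json::Result<()> {
    // "geglu" resolves via the alias added in this patch.
    let act: Activation = serde_json::from_str("\"geglu\"")?;
    assert!(matches!(act, Activation::GeGlu));
    Ok(())
}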
diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs
index ba020746b5..de8bc90020 100644
--- a/candle-nn/src/var_map.rs
+++ b/candle-nn/src/var_map.rs
@@ -2,36 +2,165 @@
 //!
 use candle::{DType, Device, Result, Shape, Tensor, Var};
 use std::collections::HashMap;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, RwLock};
 
+/// Storage backend trait for VarMap - allows different synchronization strategies
+pub trait VarStorage: Send + Sync + Clone {
+    fn new() -> Self;
+    fn get_var(&self, name: &str) -> Option<Var>;
+    fn all_vars(&self) -> Vec<Var>;
+    fn insert_var(&self, name: String, var: Var);
+    fn contains_key(&self, name: &str) -> bool;
+    fn len(&self) -> usize;
+    fn iter_for_save(&self) -> Vec<(String, Var)>;
+    fn iter_for_load(&self) -> Vec<(String, Var)>;
+    fn iter_mut_for_load(&self) -> Vec<(String, Var)>;
+}
+
+/// Original Mutex-based storage (for training)
+#[derive(Clone)]
+pub struct MutexStorage {
+    data: Arc<Mutex<HashMap<String, Var>>>,
+}
+
+/// New RwLock-based storage (for concurrent inference)
+#[derive(Clone)]
+pub struct RwLockStorage {
+    data: Arc<RwLock<HashMap<String, Var>>>,
+}
+
+// Implementation for existing Mutex storage - maintains exact original behavior
+impl VarStorage for MutexStorage {
+    fn new() -> Self {
+        Self {
+            data: Arc::new(Mutex::new(HashMap::new())),
+        }
+    }
+
+    fn get_var(&self, name: &str) -> Option<Var> {
+        let data = self.data.lock().unwrap();
+        data.get(name).cloned()
+    }
+
+    fn all_vars(&self) -> Vec<Var> {
+        let data = self.data.lock().unwrap();
+        #[allow(clippy::map_clone)]
+        data.values().map(|c| c.clone()).collect::<Vec<_>>()
+    }
+
+    fn insert_var(&self, name: String, var: Var) {
+        let mut data = self.data.lock().unwrap();
+        data.insert(name, var);
+    }
+
+    fn contains_key(&self, name: &str) -> bool {
+        let data = self.data.lock().unwrap();
+        data.contains_key(name)
+    }
+
+    fn len(&self) -> usize {
+        let data = self.data.lock().unwrap();
+        data.len()
+    }
+
+    fn iter_for_save(&self) -> Vec<(String, Var)> {
+        let data = self.data.lock().unwrap();
+        data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+    }
+
+    fn iter_for_load(&self) -> Vec<(String, Var)> {
+        let data = self.data.lock().unwrap();
+        data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+    }
+
+    fn iter_mut_for_load(&self) -> Vec<(String, Var)> {
+        let data = self.data.lock().unwrap();
+        data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+    }
+}
+
+// Implementation for RwLock storage
+impl VarStorage for RwLockStorage {
+    fn new() -> Self {
+        Self {
+            data: Arc::new(RwLock::new(HashMap::new())),
+        }
+    }
+
+    fn get_var(&self, name: &str) -> Option<Var> {
+        let data = self.data.read().unwrap();
+        data.get(name).cloned()
+    }
+
+    fn all_vars(&self) -> Vec<Var> {
+        let data = self.data.read().unwrap();
+        #[allow(clippy::map_clone)]
+        data.values().map(|c| c.clone()).collect::<Vec<_>>()
+    }
+
+    fn insert_var(&self, name: String, var: Var) {
+        let mut data = self.data.write().unwrap();
+        data.insert(name, var);
+    }
+
+    fn contains_key(&self, name: &str) -> bool {
+        let data = self.data.read().unwrap();
+        data.contains_key(name)
+    }
+
+    fn len(&self) -> usize {
+        let data = self.data.read().unwrap();
+        data.len()
+    }
+
+    fn iter_for_save(&self) -> Vec<(String, Var)> {
+        let data = self.data.read().unwrap();
+        data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+    }
+
+    fn iter_for_load(&self) -> Vec<(String, Var)> {
+        let data = self.data.read().unwrap();
+        data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+    }
+
+    fn iter_mut_for_load(&self) -> Vec<(String, Var)> {
+        let data = self.data.read().unwrap();
+        data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
+    }
+}
+
+// Generic VarMap implementation
+#[derive(Clone)]
+pub struct VarMapGeneric<Storage: VarStorage> {
+    storage: Storage,
+}
+
+// Type aliases for easy usage
 /// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
 /// and new variables can be added by providing some initialization config in case they are
 /// missing.
 /// `VarMap` structures can be serialized in the safetensors format.
-#[derive(Clone)]
-pub struct VarMap {
-    data: Arc<Mutex<HashMap<String, Var>>>,
-}
+pub type VarMap = VarMapGeneric<MutexStorage>; // Original (for training)
 
-impl VarMap {
+/// Concurrent version of VarMap using RwLock for better read performance in inference scenarios
+pub type ConcurrentVarMap = VarMapGeneric<RwLockStorage>;
+
+impl<Storage: VarStorage> VarMapGeneric<Storage> {
     /// Create a new empty `VarMap`.
     #[allow(clippy::new_without_default)]
     pub fn new() -> Self {
-        let data = Arc::new(Mutex::new(HashMap::new()));
-        Self { data }
+        Self {
+            storage: Storage::new(),
+        }
     }
 
     /// Retrieve all the variables currently stored in the map.
     pub fn all_vars(&self) -> Vec<Var> {
-        let tensor_data = self.data.lock().unwrap();
-        #[allow(clippy::map_clone)]
-        tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
+        self.storage.all_vars()
     }
 
     /// Save the map in the safetensors format.
     pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
-        let tensor_data = self.data.lock().unwrap();
-        let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
+        let data = self.storage.iter_for_save();
+        let data = data.iter().map(|(k, v)| (k, v.as_tensor()));
         safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
         Ok(())
     }
@@ -43,11 +172,12 @@ impl VarMap {
     pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
         let path = path.as_ref();
         let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
-        let mut tensor_data = self.data.lock().unwrap();
-        for (name, var) in tensor_data.iter_mut() {
-            let data = data.load(name, var.device())?;
-            if let Err(err) = var.set(&data) {
-                candle::bail!("error setting {name} using data from {path:?}: {err}",)
+        let vars = self.storage.iter_mut_for_load();
+
+        for (name, var) in vars {
+            let tensor_data = data.load(&name, var.device())?;
+            if let Err(err) = var.set(&tensor_data) {
+                candle::bail!("error setting {name} using data from {path:?}: {err}")
             }
         }
         Ok(())
@@ -55,13 +185,12 @@ impl VarMap {
     /// Set a named variable to some value.
     pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
-        let tensor_data = self.data.lock().unwrap();
         let name = name.as_ref();
-        match tensor_data.get(name) {
+        match self.storage.get_var(name) {
             None => candle::bail!("cannot find {name} in VarMap"),
             Some(var) => {
                 if let Err(err) = var.set(value.as_ref()) {
-                    candle::bail!("error setting {name}: {err}",)
+                    candle::bail!("error setting {name}: {err}")
                 }
             }
         }
@@ -76,14 +205,13 @@ impl VarMap {
         &mut self,
         iter: I,
     ) -> Result<()> {
-        let tensor_data = self.data.lock().unwrap();
         for (name, value) in iter {
             let name = name.as_ref();
-            match tensor_data.get(name) {
+            match self.storage.get_var(name) {
                 None => candle::bail!("cannot find {name} in VarMap"),
                 Some(var) => {
                     if let Err(err) = var.set(value.as_ref()) {
-                        candle::bail!("error setting {name}: {err}",)
+                        candle::bail!("error setting {name}: {err}")
                     }
                 }
             }
@@ -101,21 +229,72 @@ impl VarMap {
         device: &Device,
     ) -> Result<Tensor> {
         let shape = shape.into();
-        let mut tensor_data = self.data.lock().unwrap();
-        if let Some(tensor) = tensor_data.get(path) {
-            let tensor_shape = tensor.shape();
+        if let Some(existing_var) = self.storage.get_var(path) {
+            let tensor_shape = existing_var.shape();
             if &shape != tensor_shape {
                 candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
             }
-            return Ok(tensor.as_tensor().clone());
+            return Ok(existing_var.as_tensor().clone());
         }
         let var = init.var(shape, dtype, device)?;
         let tensor = var.as_tensor().clone();
-        tensor_data.insert(path.to_string(), var);
+        self.storage.insert_var(path.to_string(), var);
         Ok(tensor)
     }
 
-    pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
-        &self.data
+    /// Get a variable by name (method for compatibility).
+    pub fn get_var(&self, name: &str) -> Option<Var> {
+        self.storage.get_var(name)
+    }
+
+    /// Insert a new variable (method for compatibility).
+    pub fn insert(&self, name: String, var: Var) {
+        self.storage.insert_var(name, var);
+    }
+
+    /// Check if a variable exists (method for compatibility).
+    pub fn contains_key(&self, name: &str) -> bool {
+        self.storage.contains_key(name)
+    }
+
+    /// Convert to the other storage type (for migration)
+    pub fn into_concurrent(self) -> ConcurrentVarMap
+    where
+        Storage: VarStorage,
+    {
+        let concurrent = ConcurrentVarMap::new();
+
+        // Transfer all variables
+        for (name, var) in self.storage.iter_for_save() {
+            concurrent.insert(name, var);
+        }
+
+        concurrent
+    }
+}
+
+impl VarMap {
+    pub fn data(&self) -> &Arc<Mutex<HashMap<String, Var>>> {
+        &self.storage.data
+    }
+}
+
+impl ConcurrentVarMap {
+    pub fn read_data(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, Var>> {
+        self.storage.data.read().unwrap()
+    }
+
+    pub fn write_data(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Var>> {
+        self.storage.data.write().unwrap()
+    }
+
+    pub fn get_vars_batch(&self, names: &[&str]) -> HashMap<String, Var> {
+        let data = self.storage.data.read().unwrap();
+        names
+            .iter()
+            .filter_map(|&name| data.get(name).map(|v| (name.to_string(), v.clone())))
+            .collect()
+    }
+
+    pub fn data(&self) -> &Arc<RwLock<HashMap<String, Var>>> {
+        &self.storage.data
+    }
 }
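A sketch of the read path the RwLock storage is built for: many threads resolving variables from one shared map without serializing on a `Mutex` (the variable name, shape, and thread count are illustrative):

use candle::{DType, Device, Result};
use candle_nn::var_map::ConcurrentVarMap;
use candle_nn::Init;
use std::sync::Arc;
use std::thread;

fn main() -> Result<()> {
    let vars = Arc::new(ConcurrentVarMap::new());
    vars.get((16, 16), "w", Init::Const(1.), DType::F32, &Device::Cpu)?;

    // Readers only take the RwLock in shared mode, so these lookups proceed
    // concurrently instead of queueing behind a single lock holder.
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let vars = Arc::clone(&vars);
            thread::spawn(move || vars.get_var("w").is_some())
        })
        .collect();
    for h in handles {
        assert!(h.join().unwrap());
    }
    Ok(())
}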
diff --git a/candle-nn/tests/activations_added.rs b/candle-nn/tests/activations_added.rs
new file mode 100644
index 0000000000..2b4fb1ab61
--- /dev/null
+++ b/candle-nn/tests/activations_added.rs
@@ -0,0 +1,299 @@
+use candle::{Device, Module, Result, Tensor};
+use candle_nn::Activation;
+
+#[test]
+fn test_glu_activation() -> Result<()> {
+    let device = Device::Cpu;
+
+    // Test GLU with even dimension (4 -> 2)
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?;
+    let glu = Activation::Glu;
+    let output = glu.forward(&input)?;
+
+    // Expected: sigmoid([1, 2]) * [3, 4]
+    assert_eq!(output.dims(), &[1, 2]);
+
+    // Verify output is finite and reasonable
+    let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+    assert!(output_vals.iter().all(|&x| x.is_finite()));
+
+    println!("GLU output: {:?}", output_vals);
+    Ok(())
+}
+
+#[test]
+fn test_glu_odd_dimension_error() {
+    let device = Device::Cpu;
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0]], &device).unwrap();
+    let glu = Activation::Glu;
+
+    // Should error with odd dimension
+    let result = glu.forward(&input);
+    assert!(result.is_err(), "GLU should error with odd dimension");
+
+    // Check error message contains expected text
+    let error_msg = format!("{}", result.unwrap_err());
+    assert!(
+        error_msg.contains("even"),
+        "Error should mention even dimension requirement"
+    );
+    println!("GLU correctly rejects odd dimensions: {}", error_msg);
+}
+
+#[test]
+fn test_geglu_activation() -> Result<()> {
+    let device = Device::Cpu;
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?;
+    let geglu = Activation::GeGlu;
+    let output = geglu.forward(&input)?;
+
+    assert_eq!(output.dims(), &[1, 2]);
+
+    let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+    assert!(output_vals.iter().all(|&x| x.is_finite()));
+
+    println!("GeGLU output: {:?}", output_vals);
+    Ok(())
+}
+
+#[test]
+fn test_reglu_activation() -> Result<()> {
+    let device = Device::Cpu;
+    let input = Tensor::new(&[[-1.0f32, 2.0, 3.0, 4.0]], &device)?;
+    let reglu = Activation::ReGlu;
+    let output = reglu.forward(&input)?;
+
+    assert_eq!(output.dims(), &[1, 2]);
+
+    // ReLU(-1) = 0, ReLU(2) = 2
+    // output = [0 * 3, 2 * 4] = [0, 8]
+    let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+    assert_eq!(output_vals[0], 0.0);
+    assert_eq!(output_vals[1], 8.0);
+
+    println!("ReGLU output: {:?}", output_vals);
+    Ok(())
+}
+
+#[test]
+fn test_multidimensional_glu() -> Result<()> {
+    let device = Device::Cpu;
+    // Test with 3D tensor (batch_size=2, seq_len=3, hidden_dim=4)
+    let input = Tensor::randn(0f32, 1f32, (2, 3, 4), &device)?;
+    let glu = Activation::Glu;
+    let output = glu.forward(&input)?;
+
+    // Should halve the last dimension: (2, 3, 4) -> (2, 3, 2)
+    assert_eq!(output.dims(), &[2, 3, 2]);
+
+    println!(
+        "Multidimensional GLU: {:?} -> {:?}",
+        input.dims(),
+        output.dims()
+    );
+    Ok(())
+}
+
+#[test]
+fn test_phi3_compatibility() -> Result<()> {
+    let device = Device::Cpu;
+
+    // Test that GLU variants work with typical transformer dimensions
+    let transformer_input = Tensor::randn(0f32, 1f32, (2, 128, 2048), &device)?; // (batch, seq, hidden*2)
+
+    let glu = Activation::Glu;
+    let output = glu.forward(&transformer_input)?;
+
+    // Should halve last dimension: (2, 128, 2048) -> (2, 128, 1024)
+    assert_eq!(output.dims(), &[2, 128, 1024]);
+
+    println!(
+        "Phi-3 compatibility: {:?} -> {:?}",
+        transformer_input.dims(),
+        output.dims()
+    );
+    Ok(())
+}
+
+#[test]
+fn test_glu_variants_comparison() -> Result<()> {
+    let device = Device::Cpu;
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?;
+
+    let glu_output = Activation::Glu.forward(&input)?;
+    let geglu_output = Activation::GeGlu.forward(&input)?;
+    let reglu_output = Activation::ReGlu.forward(&input)?;
+
+    // All should have same output shape
+    assert_eq!(glu_output.dims(), &[1, 2]);
+    assert_eq!(geglu_output.dims(), &[1, 2]);
+    assert_eq!(reglu_output.dims(), &[1, 2]);
+
+    // Values should be different due to different gating functions
+    let glu_vals: Vec<f32> = glu_output.flatten_all()?.to_vec1()?;
+    let geglu_vals: Vec<f32> = geglu_output.flatten_all()?.to_vec1()?;
+    let reglu_vals: Vec<f32> = reglu_output.flatten_all()?.to_vec1()?;
+
+    println!("GLU values: {:?}", glu_vals);
+    println!("GeGLU values: {:?}", geglu_vals);
+    println!("ReGLU values: {:?}", reglu_vals);
+
+    // GLU and GeGLU should have different values (sigmoid vs GELU)
+    assert_ne!(glu_vals, geglu_vals);
+    // GLU and ReGLU should have different values (sigmoid vs ReLU)
+    assert_ne!(glu_vals, reglu_vals);
+
+    Ok(())
+}
+
+#[test]
+fn test_glu_variants_with_negative_inputs() -> Result<()> {
+    let device = Device::Cpu;
+    let input = Tensor::new(&[[-2.0f32, -1.0, 3.0, 4.0]], &device)?;
+
+    // Test that all variants handle negative inputs correctly
+    let glu_output = Activation::Glu.forward(&input)?;
+    let geglu_output = Activation::GeGlu.forward(&input)?;
+    let reglu_output = Activation::ReGlu.forward(&input)?;
+
+    assert_eq!(glu_output.dims(), &[1, 2]);
+    assert_eq!(geglu_output.dims(), &[1, 2]);
+    assert_eq!(reglu_output.dims(), &[1, 2]);
+
+    let glu_vals: Vec<f32> = glu_output.flatten_all()?.to_vec1()?;
+    let geglu_vals: Vec<f32> = geglu_output.flatten_all()?.to_vec1()?;
+    let reglu_vals: Vec<f32> = reglu_output.flatten_all()?.to_vec1()?;
+
+    println!("Negative input GLU: {:?}", glu_vals);
+    println!("Negative input GeGLU: {:?}", geglu_vals);
+    println!("Negative input ReGLU: {:?}", reglu_vals);
+
+    // All should produce finite values
+    assert!(glu_vals.iter().all(|&x| x.is_finite()));
+    assert!(geglu_vals.iter().all(|&x| x.is_finite()));
+    assert!(reglu_vals.iter().all(|&x| x.is_finite()));
+
+    Ok(())
+}
+
+#[test]
+fn test_core_vs_enum_consistency() -> Result<()> {
+    let device = Device::Cpu;
+    let input = Tensor::new(&[[1.0f32, 2.0, 3.0, 4.0]], &device)?;
+
+    // Test that core tensor methods match activation enum
+    let core_glu = input.glu()?;
+    let enum_glu = Activation::Glu.forward(&input)?;
+
+    let core_geglu = input.geglu()?;
+    let enum_geglu = Activation::GeGlu.forward(&input)?;
+
+    let core_reglu = input.reglu()?;
+    let enum_reglu = Activation::ReGlu.forward(&input)?;
+
+    // Compare outputs (allowing for small floating point differences)
+    let core_glu_vals: Vec<f32> = core_glu.flatten_all()?.to_vec1()?;
+    let enum_glu_vals: Vec<f32> = enum_glu.flatten_all()?.to_vec1()?;
+
+    let core_geglu_vals: Vec<f32> = core_geglu.flatten_all()?.to_vec1()?;
+    let enum_geglu_vals: Vec<f32> = enum_geglu.flatten_all()?.to_vec1()?;
+
+    let core_reglu_vals: Vec<f32> = core_reglu.flatten_all()?.to_vec1()?;
+    let enum_reglu_vals: Vec<f32> = enum_reglu.flatten_all()?.to_vec1()?;
+
+    // GLU consistency
+    for (core_val, enum_val) in core_glu_vals.iter().zip(enum_glu_vals.iter()) {
+        assert!(
+            (core_val - enum_val).abs() < 1e-6,
+            "GLU core vs enum mismatch: {} vs {}",
+            core_val,
+            enum_val
+        );
+    }
+
+    // GeGLU consistency
+    for (core_val, enum_val) in core_geglu_vals.iter().zip(enum_geglu_vals.iter()) {
+        assert!(
+            (core_val - enum_val).abs() < 1e-6,
+            "GeGLU core vs enum mismatch: {} vs {}",
+            core_val,
+            enum_val
+        );
+    }
+
+    // ReGLU consistency
+    for (core_val, enum_val) in core_reglu_vals.iter().zip(enum_reglu_vals.iter()) {
+        assert!(
+            (core_val - enum_val).abs() < 1e-6,
+            "ReGLU core vs enum mismatch: {} vs {}",
+            core_val,
+            enum_val
+        );
+    }
+
+    println!("Core vs Enum consistency test passed for all GLU variants");
+    Ok(())
+}
+
+#[test]
+fn test_glu_performance_characteristics() -> Result<()> {
+    let device = Device::Cpu;
+
+    // Test different sizes to verify linear scaling
+    let sizes = vec![8, 16, 32, 64];
+
+    for size in sizes {
+        let input = Tensor::randn(0f32, 1f32, (1, size), &device)?;
+
+        let glu_output = Activation::Glu.forward(&input)?;
+        let geglu_output = Activation::GeGlu.forward(&input)?;
+        let reglu_output = Activation::ReGlu.forward(&input)?;
+
+        // All should halve the input size
+        assert_eq!(glu_output.dims(), &[1, size / 2]);
+        assert_eq!(geglu_output.dims(), &[1, size / 2]);
+        assert_eq!(reglu_output.dims(), &[1, size / 2]);
+
+        println!(
+            "Size {}: All GLU variants produce correct output dimensions",
+            size
+        );
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_glu_gradient_flow() -> Result<()> {
+    let device = Device::Cpu;
+
+    // Test that GLU variants produce non-degenerate outputs (a cheap proxy;
+    // no backward pass is taken here)
+    let input = Tensor::randn(0f32, 1f32, (2, 8), &device)?;
+
+    let activations = vec![
+        ("GLU", Activation::Glu),
+        ("GeGLU", Activation::GeGlu),
+        ("ReGLU", Activation::ReGlu),
+    ];
+
+    for (name, activation) in activations {
+        let output = activation.forward(&input)?;
+
+        // Verify the output is not zero everywhere
+        let output_vals: Vec<f32> = output.flatten_all()?.to_vec1()?;
+        let non_zero_count = output_vals.iter().filter(|&&x| x.abs() > 1e-6).count();
+
+        assert!(
+            non_zero_count > 0,
+            "Activation {} produced all-zero output",
+            name
+        );
+
+        println!(
+            "{} gradient flow test passed ({} non-zero values)",
+            name, non_zero_count
+        );
+    }
+
+    Ok(())
+}
diff --git a/candle-nn/tests/var_map_compatibility.rs b/candle-nn/tests/var_map_compatibility.rs
new file mode 100644
index 0000000000..0db951020b
--- /dev/null
+++ b/candle-nn/tests/var_map_compatibility.rs
@@ -0,0 +1,488 @@
+use candle::{DType, Device, Result, Tensor, Var};
+use candle_nn::var_map::ConcurrentVarMap;
+use candle_nn::{Init, VarMap};
+use std::sync::{Arc, Barrier};
+use std::thread;
+
+#[test]
+fn test_basic_operations_compatibility() -> Result<()> {
+    let device = Device::Cpu;
+
+    // Original implementation
+    let original = {
+        #[derive(Clone)]
+        struct OriginalVarMap {
+            data: Arc<std::sync::Mutex<std::collections::HashMap<String, Var>>>,
+        }
+
+        impl OriginalVarMap {
+            fn new() -> Self {
+                Self {
+                    data: Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
+                }
+            }
+
+            fn get<S: Into<candle::Shape>>(
+                &self,
+                shape: S,
+                path: &str,
+                init: Init,
+                dtype: DType,
+                device: &Device,
+            ) -> Result<Tensor> {
+                let shape = shape.into();
+                let mut tensor_data = self.data.lock().unwrap();
+                if let Some(tensor) = tensor_data.get(path) {
+                    let tensor_shape = tensor.shape();
+                    if &shape != tensor_shape {
+                        candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
+                    }
+                    return Ok(tensor.as_tensor().clone());
+                }
+                let var = init.var(shape, dtype, device)?;
+                let tensor = var.as_tensor().clone();
+                tensor_data.insert(path.to_string(), var);
+                Ok(tensor)
+            }
+
+            fn all_vars(&self) -> Vec<Var> {
+                let tensor_data = self.data.lock().unwrap();
+                tensor_data.values().cloned().collect()
+            }
+        }
+
+        OriginalVarMap::new()
+    };
+
+    // New implementation
+    let updated = VarMap::new();
+
+    // Test 1: Basic get operations
+    let t1_orig = original.get(
+        (2, 3),
+        "test1",
+        Init::Randn {
+            mean: 0.,
+            stdev: 1.,
+        },
+        DType::F32,
+        &device,
+    )?;
+    let t1_updated = updated.get(
+        (2, 3),
+        "test1",
+        Init::Randn {
+            mean: 0.,
+            stdev: 1.,
+        },
+        DType::F32,
+        &device,
+    )?;
+
+    // Shapes should match
+    assert_eq!(t1_orig.shape(), t1_updated.shape());
+
+    // Test 2: Repeated get returns same variable
+    let t1_orig_2 = original.get((2, 3), "test1", Init::Const(0.), DType::F32, &device)?;
+    let t1_updated_2 = updated.get((2, 3), "test1", Init::Const(0.), DType::F32, &device)?;
+
+    // Should return existing variables
+    assert_eq!(t1_orig.shape(), t1_orig_2.shape());
+    assert_eq!(t1_updated.shape(), t1_updated_2.shape());
+
+    // Test 3: Multiple variables
+    for i in 0..10 {
+        let name = format!("var_{}", i);
+        original.get(
+            (i + 1, i + 2),
+            &name,
+            Init::Const(i as f64),
+            DType::F32,
+            &device,
+        )?;
+        updated.get(
+            (i + 1, i + 2),
+            &name,
+            Init::Const(i as f64),
+            DType::F32,
+            &device,
+        )?;
+    }
+
+    // Verify all variables match
+    assert_eq!(original.all_vars().len(), updated.all_vars().len());
+
+    Ok(())
+}
+
+#[test]
+fn test_concurrent_reads_match_sequential() -> Result<()> {
+    let device = Device::Cpu;
+    let updated = Arc::new(VarMap::new());
+    let concurrent = Arc::new(ConcurrentVarMap::new());
+
+    // Initialize both with same data
+    for i in 0..100 {
+        let name = format!("var_{}", i);
+        let shape = (10, 10);
+        let init = Init::Const(i as f64);
+
+        updated.get(shape, &name, init, DType::F32, &device)?;
+        concurrent.get(shape, &name, init, DType::F32, &device)?;
+    }
+
+    // Test concurrent reads
+    let n_threads = 8;
+    let barrier = Arc::new(Barrier::new(n_threads));
+    let mut handles = vec![];
+
+    for _thread_id in 0..n_threads {
+        let updated_clone: Arc<VarMap> = Arc::clone(&updated);
+        let concurrent_clone: Arc<ConcurrentVarMap> = Arc::clone(&concurrent);
+        let barrier_clone = Arc::clone(&barrier);
+        let device_clone = device.clone();
+
+        let handle = thread::spawn(move || {
+            barrier_clone.wait();
+
+            // Each thread reads multiple variables
+            for i in 0..100 {
+                let name = format!("var_{}", i);
+                let shape = (10, 10);
+
+                let v1 = updated_clone
+                    .get(shape, &name, Init::Const(0.), DType::F32, &device_clone)
+                    .unwrap();
+                let v2 = concurrent_clone
+                    .get(shape, &name, Init::Const(0.), DType::F32, &device_clone)
+                    .unwrap();
+
+                // Values should match
+                assert_eq!(v1.shape(), v2.shape());
+
+                // Compare flattened data for any shape
+                let data1 = v1.flatten_all().unwrap().to_vec1::<f32>().unwrap();
+                let data2 = v2.flatten_all().unwrap().to_vec1::<f32>().unwrap();
+                assert_eq!(data1, data2);
+            }
+        });
+
+        handles.push(handle);
+    }
+
+    // Wait for all threads
+    for handle in handles {
+        handle.join().unwrap();
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_save_load_compatibility() -> Result<()> {
+    let device = Device::Cpu;
+    let original = VarMap::new();
+    let updated = VarMap::new();
+
+    // Create identical data
+    for i in 0..20 {
+        let name = format!("layer_{}.weight", i);
+        let shape = (64, 64);
+        // Use a deterministic init for comparison
+        original.get(shape, &name, Init::Const(i as f64), DType::F32, &device)?;
+        updated.get(shape, &name, Init::Const(i as f64), DType::F32, &device)?;
+    }
+
+    // Save both
+    let original_path = "/tmp/test_original_varmap.safetensors";
+    let updated_path = "/tmp/test_updated_varmap.safetensors";
+
+    original.save(original_path)?;
+    updated.save(updated_path)?;
+
+    // Files should be identical
+    let original_bytes = std::fs::read(original_path)?;
+    let updated_bytes = std::fs::read(updated_path)?;
+    assert_eq!(original_bytes, updated_bytes, "Saved files differ!");
+
+    // Test loading
+    let mut original_loaded = VarMap::new();
+    let mut updated_loaded = VarMap::new();
+
+    // Pre-create variables for loading
+    for i in 0..20 {
+        let name = format!("layer_{}.weight", i);
+        original_loaded.get((64, 64), &name, Init::Const(0.), DType::F32, &device)?;
+        updated_loaded.get((64, 64), &name, Init::Const(0.), DType::F32, &device)?;
+    }
+
+    original_loaded.load(original_path)?;
+    updated_loaded.load(updated_path)?;
+
+    // Verify loaded data matches - check a few specific variables
+    for i in 0..20 {
+        let name = format!("layer_{}.weight", i);
+        let orig_var =
+            original_loaded.get((64, 64), &name, Init::Const(0.), DType::F32, &device)?;
+        let updated_var =
+            updated_loaded.get((64, 64), &name, Init::Const(0.), DType::F32, &device)?;
+
+        // Compare shapes
+        assert_eq!(orig_var.shape(), updated_var.shape());
+
+        // Compare values - flatten first
+        let orig_data: Vec<f32> = orig_var.flatten_all()?.to_vec1()?;
+        let updated_data: Vec<f32> = updated_var.flatten_all()?.to_vec1()?;
+
+        // Values should be close to i (the const value we used)
+        for (o, u) in orig_data.iter().zip(updated_data.iter()) {
+            assert!((o - u).abs() < 1e-6, "Value mismatch in {}", name);
+        }
+    }
+
+    // Cleanup
+    std::fs::remove_file(original_path).ok();
+    std::fs::remove_file(updated_path).ok();
+
+    Ok(())
+}
+
+#[test]
+fn test_set_operations_compatibility() -> Result<()> {
+    let device = Device::Cpu;
+    let mut original = VarMap::new();
+    let mut updated = VarMap::new();
+
+    // Initialize with same data
+    for i in 0..10 {
+        let name = format!("param_{}", i);
+        original.get((5, 5), &name, Init::Const(0.), DType::F32, &device)?;
+        updated.get((5, 5), &name, Init::Const(0.), DType::F32, &device)?;
+    }
+
+    // Test set_one
+    let new_value = Tensor::ones((5, 5), DType::F32, &device)?;
+    original.set_one("param_0", &new_value)?;
+    updated.set_one("param_0", &new_value)?;
+
+    // Test set with iterator
+    let updates: Vec<(String, Tensor)> = (1..5)
+        .map(|i| {
+            let name = format!("param_{}", i);
+            let value = Tensor::full(i as f32, (5, 5), &device).unwrap();
+            (name, value)
+        })
+        .collect();
+
+    original.set(updates.iter().map(|(k, v)| (k, v)))?;
+    updated.set(updates.iter().map(|(k, v)| (k, v)))?;
+
+    // Verify specific values match
+    for i in 0..5 {
+        let name = format!("param_{}", i);
+        let orig_tensor = original.get((5, 5), &name, Init::Const(0.), DType::F32, &device)?;
+        let updated_tensor = updated.get((5, 5), &name, Init::Const(0.), DType::F32, &device)?;
+
+        // Flatten and compare
+        let orig_data: Vec<f32> = orig_tensor.flatten_all()?.to_vec1()?;
+        let updated_data: Vec<f32> = updated_tensor.flatten_all()?.to_vec1()?;
+
+        let expected_val = if i == 0 { 1.0 } else { i as f32 };
+
+        for (o, u) in orig_data.iter().zip(updated_data.iter()) {
+            assert!(
+                (o - expected_val).abs() < 1e-6,
+                "Original value mismatch for {}",
+                name
+            );
+            assert!(
+                (u - expected_val).abs() < 1e-6,
+                "Updated value mismatch for {}",
+                name
+            );
+            assert!((o - u).abs() < 1e-6, "Values don't match for {}", name);
+        }
+    }
+
+    // Verify unchanged values
+    for i in 5..10 {
+        let name = format!("param_{}", i);
+        let orig_tensor = original.get((5, 5), &name, Init::Const(0.), DType::F32, &device)?;
+        let updated_tensor = updated.get((5, 5), &name, Init::Const(0.), DType::F32, &device)?;
+
+        let orig_data: Vec<f32> = orig_tensor.flatten_all()?.to_vec1()?;
+        let updated_data: Vec<f32> = updated_tensor.flatten_all()?.to_vec1()?;
+
+        // These should still be 0
+        for (o, u) in orig_data.iter().zip(updated_data.iter()) {
+            assert!(
+                o.abs() < 1e-6,
+                "Original unchanged value not zero for {}",
+                name
+            );
+            assert!(
+                u.abs() < 1e-6,
+                "Updated unchanged value not zero for {}",
+                name
+            );
+        }
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_error_conditions_match() -> Result<()> {
+    let device = Device::Cpu;
+    let mut original = VarMap::new();
+    let mut updated = VarMap::new();
+
+    // Test shape mismatch error
+    original.get((2, 3), "test", Init::Const(0.), DType::F32, &device)?;
+    updated.get((2, 3), "test", Init::Const(0.), DType::F32, &device)?;
+
+    // Both should fail with shape mismatch
+    let orig_err = original.get((3, 2), "test", Init::Const(0.), DType::F32, &device);
+    let updated_err = updated.get((3, 2), "test", Init::Const(0.), DType::F32, &device);
+
+    assert!(orig_err.is_err());
+    assert!(updated_err.is_err());
+
+    // Test set_one on non-existent variable
+    let tensor = Tensor::ones((2, 2), DType::F32, &device)?;
+    let orig_err = original.set_one("nonexistent", &tensor);
+    let updated_err = updated.set_one("nonexistent", &tensor);
+
+    assert!(orig_err.is_err());
+    assert!(updated_err.is_err());
+
+    Ok(())
+}
+
+#[test]
+fn test_concurrent_varmap_specific_features() -> Result<()> {
+    let device = Device::Cpu;
+    let concurrent = ConcurrentVarMap::new();
+
+    // Initialize data
+    for i in 0..50 {
+        let name = format!("weight_{}", i);
+        concurrent.get(
+            (32, 32),
+            &name,
+            Init::Randn {
+                mean: 0.,
+                stdev: 0.02,
+            },
+            DType::F32,
+            &device,
+        )?;
+    }
+
+    // Test batch operations
+    let names: Vec<&str> = (0..10)
+        .map(|i| Box::leak(format!("weight_{}", i).into_boxed_str()) as &str)
+        .collect();
+    let batch_vars = concurrent.get_vars_batch(&names);
+
+    assert_eq!(batch_vars.len(), 10);
+    for (_name, var) in batch_vars {
+        assert_eq!(var.shape().dims(), &[32, 32]);
+    }
+
+    // Test concurrent read access
+    let n_readers = 10;
+    let barrier = Arc::new(Barrier::new(n_readers));
+    let concurrent: Arc<ConcurrentVarMap> = Arc::new(concurrent);
+
+    let handles: Vec<_> = (0..n_readers)
+        .map(|_| {
+            let concurrent: Arc<ConcurrentVarMap> = Arc::clone(&concurrent);
+            let barrier = Arc::clone(&barrier);
+
+            thread::spawn(move || {
+                barrier.wait();
+
+                // Multiple concurrent reads
+                let _guard = concurrent.read_data();
+                thread::sleep(std::time::Duration::from_millis(10));
+
+                // Should not block other readers
+                assert!(concurrent.all_vars().len() >= 50);
+            })
+        })
+        .collect();
+
+    for handle in handles {
+        handle.join().unwrap();
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_varmap_conversion() -> Result<()> {
+    let device = Device::Cpu;
+    let original = VarMap::new();
+
+    // Add some data
+    for i in 0..25 {
+        let name = format!("conv_{}.weight", i);
+        original.get(
+            (3, 3, 64, 64),
+            &name,
+            Init::Kaiming {
+                dist: candle_nn::init::NormalOrUniform::Normal,
+                fan: candle_nn::init::FanInOut::FanIn,
+                non_linearity: candle_nn::init::NonLinearity::ReLU,
+            },
+            DType::F32,
+            &device,
+        )?;
+    }
+
+    // Convert to concurrent
+    let concurrent = original.clone().into_concurrent();
+
+    // Verify all data transferred
+    assert_eq!(original.all_vars().len(), concurrent.all_vars().len());
+
+    // Verify values match
+    let orig_vars = original.all_vars();
+    let conc_vars = concurrent.all_vars();
+
+    for (orig, conc) in orig_vars.iter().zip(conc_vars.iter()) {
+        assert_eq!(orig.shape(), conc.shape());
+        assert_eq!(orig.dtype(), conc.dtype());
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_backend_trait_implementation() -> Result<()> {
+    use candle_nn::VarBuilder;
+
+    let device = Device::Cpu;
+
+    // Test that VarMap works as SimpleBackend
+    let varmap = VarMap::new();
+    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
+
+    // Create some layers
+    let weight1 = vb.get((128, 256), "layer1.weight")?;
+    let weight2 = vb.get((256, 512), "layer2.weight")?;
+
+    assert_eq!(weight1.shape().dims(), &[128, 256]);
+    assert_eq!(weight2.shape().dims(), &[256, 512]);
+
+    // Test ConcurrentVarMap as backend
+    let concurrent = ConcurrentVarMap::new();
+    let vb_concurrent =
+        VarBuilder::from_backend(Box::new(concurrent.clone()), DType::F32, device.clone());
+
+    let weight3 = vb_concurrent.get((64, 128), "layer3.weight")?;
+    assert_eq!(weight3.shape().dims(), &[64, 128]);
+
+    Ok(())
+}
diff --git a/candle-nn/tests/var_map_integration.rs b/candle-nn/tests/var_map_integration.rs
new file mode 100644
index 0000000000..2daacc2d02
--- /dev/null
+++ b/candle-nn/tests/var_map_integration.rs
@@ -0,0 +1,36 @@
+use candle::{DType, Device, Module, Result, Tensor};
+use candle_nn::var_map::ConcurrentVarMap;
+use candle_nn::{VarBuilder, VarMap};
+
+#[test]
+fn test_with_neural_network_layers() -> Result<()> {
+    let device = Device::Cpu;
+
+    // Test with original VarMap
+    let varmap1 = VarMap::new();
+    let vb1 = VarBuilder::from_varmap(&varmap1, DType::F32, &device);
+    let layer1 = candle_nn::linear(768, 512, vb1.pp("layer1"))?;
+
+    // Test with updated VarMap
+    let varmap2 = VarMap::new();
+    let vb2 = VarBuilder::from_varmap(&varmap2, DType::F32, &device);
+    let layer2 = candle_nn::linear(768, 512, vb2.pp("layer1"))?;
+
+    // Test with ConcurrentVarMap - now we need to handle it differently
+    // since from_varmap expects VarMap specifically
+    let varmap3 = ConcurrentVarMap::new();
+    let vb3 = VarBuilder::from_backend(Box::new(varmap3.clone()), DType::F32, device.clone());
+    let layer3 = candle_nn::linear(768, 512, vb3.pp("layer1"))?;
+
+    // All should work identically
+    let input = Tensor::randn(0f32, 1f32, (32, 768), &device)?;
+
+    let out1 = layer1.forward(&input)?;
+    let out2 = layer2.forward(&input)?;
+    let out3 = layer3.forward(&input)?;
+
+    assert_eq!(out1.shape(), out2.shape());
+    assert_eq!(out2.shape(), out3.shape());
+
+    Ok(())
+}
diff --git a/candle-transformers/src/models/phi3.rs b/candle-transformers/src/models/phi3.rs
index 6535d9a4fd..9f4fe28b68 100644
--- a/candle-transformers/src/models/phi3.rs
+++ b/candle-transformers/src/models/phi3.rs
@@ -21,7 +21,7 @@
 // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
 use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-use candle_nn::VarBuilder;
+use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
 #[derive(Debug, Clone, serde::Deserialize)]
@@ -59,12 +59,41 @@ pub struct Config {
     #[serde(default)]
     pub tie_word_embeddings: bool,
 }
-
 impl Config {
+    pub fn mini_4k_instruct() -> Self {
+        Self {
+            vocab_size: 32064,
+            hidden_act: Activation::GeGlu,
+            hidden_size: 3072,
+            intermediate_size: 8192,
+            num_hidden_layers: 32,
+            num_attention_heads: 32,
+            num_key_value_heads: 32,
+            rms_norm_eps: 1e-5,
+            rope_theta: 10000.0,
+            bos_token_id: Some(1),
+            eos_token_id: Some(2),
+            rope_scaling: None,
+            max_position_embeddings: 4096,
+            original_max_position_embeddings: None,
+            partial_rotary_factor: None,
+            tie_word_embeddings: false,
+        }
+    }
+
+    pub fn with_activation(mut self, activation: Activation) -> Self {
+        self.hidden_act = activation;
+        self
+    }
+
     pub fn head_dim(&self) -> usize {
         self.hidden_size / self.num_attention_heads
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct RotaryEmbedding {
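A sketch of how the new constructor and builder method compose (the swapped-in activation is arbitrary, chosen only to show the override):

use candle_nn::Activation;
use candle_transformers::models::phi3::Config;

fn main() {
    // Start from the mini-4k-instruct defaults, then swap the FFN activation.
    let config = Config::mini_4k_instruct().with_activation(Activation::ReGlu);
    assert_eq!(config.hidden_act, Activation::ReGlu);
    assert_eq!(config.head_dim(), 3072 / 32);
}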