candle-nn/src/var_builder.rs (22 additions, 0 deletions)

@@ -279,6 +279,22 @@ impl SimpleBackend for VarMap {
self.data().lock().unwrap().contains_key(name)
}
}

impl SimpleBackend for crate::var_map::ConcurrentVarMap {
fn get(
&self,
s: Shape,
name: &str,
h: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
self.get(s, name, h, dtype, dev)
}

fn contains_tensor(&self, name: &str) -> bool {
self.contains_key(name)
}
}

#[allow(dead_code)]
pub struct SafeTensorWithRouting<'a> {
@@ -466,6 +482,12 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
}

impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` using a custom backend.
///
/// It is preferred to use one of the more specific constructors. This
/// constructor is provided to allow downstream users to define their own
/// backends.
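
With the `SimpleBackend` impl for `ConcurrentVarMap` above, the concurrent map can be handed to this constructor like any other backend. A minimal sketch, assuming `from_backend` keeps its boxed-backend signature and that `ConcurrentVarMap` is reachable under `candle_nn::var_map`:

```rust
use candle::{DType, Device};
use candle_nn::var_builder::VarBuilder;
use candle_nn::var_map::ConcurrentVarMap;

fn build_with_concurrent_map(dev: &Device) -> candle::Result<()> {
    let varmap = ConcurrentVarMap::new();
    // Clones share the same Arc-backed storage, so one handle can go into
    // the builder while this one stays available for reads.
    let vb = VarBuilder::from_backend(Box::new(varmap.clone()), DType::F32, dev.clone());
    let _w = vb.get_with_hints((4, 4), "layer.weight", candle_nn::init::DEFAULT_KAIMING_NORMAL)?;
    assert!(varmap.contains_key("layer.weight"));
    Ok(())
}
```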
candle-nn/src/var_map.rs (210 additions, 31 deletions)

@@ -2,36 +2,165 @@
//!
use candle::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, RwLock};

/// Storage backend trait for `VarMap` that allows different synchronization strategies.
pub trait VarStorage: Send + Sync + Clone {
fn new() -> Self;
fn get_var(&self, name: &str) -> Option<Var>;
fn all_vars(&self) -> Vec<Var>;
fn insert_var(&self, name: String, var: Var);
fn contains_key(&self, name: &str) -> bool;
fn len(&self) -> usize;
fn iter_for_save(&self) -> Vec<(String, Var)>;
fn iter_for_load(&self) -> Vec<(String, Var)>;
fn iter_mut_for_load(&self) -> Vec<(String, Var)>;
}

/// Original Mutex-based storage (for training)
#[derive(Clone)]
pub struct MutexStorage {
data: Arc<Mutex<HashMap<String, Var>>>,
}

/// New RwLock-based storage (for concurrent inference)
#[derive(Clone)]
pub struct RwLockStorage {
data: Arc<RwLock<HashMap<String, Var>>>,
}

// Implementation for the existing Mutex-based storage; preserves the original behavior exactly.
impl VarStorage for MutexStorage {
fn new() -> Self {
Self {
data: Arc::new(Mutex::new(HashMap::new())),
}
}

fn get_var(&self, name: &str) -> Option<Var> {
let data = self.data.lock().unwrap();
data.get(name).cloned()
}

fn all_vars(&self) -> Vec<Var> {
let data = self.data.lock().unwrap();
#[allow(clippy::map_clone)]
data.values().map(|c| c.clone()).collect::<Vec<_>>()
}

fn insert_var(&self, name: String, var: Var) {
let mut data = self.data.lock().unwrap();
data.insert(name, var);
}

fn contains_key(&self, name: &str) -> bool {
let data = self.data.lock().unwrap();
data.contains_key(name)
}

fn len(&self) -> usize {
let data = self.data.lock().unwrap();
data.len()
}

fn iter_for_save(&self) -> Vec<(String, Var)> {
let data = self.data.lock().unwrap();
data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}

fn iter_for_load(&self) -> Vec<(String, Var)> {
let data = self.data.lock().unwrap();
data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}

fn iter_mut_for_load(&self) -> Vec<(String, Var)> {
let data = self.data.lock().unwrap();
data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}
}

// Implementation for RwLock storage
impl VarStorage for RwLockStorage {
fn new() -> Self {
Self {
data: Arc::new(RwLock::new(HashMap::new())),
}
}

fn get_var(&self, name: &str) -> Option<Var> {
let data = self.data.read().unwrap();
data.get(name).cloned()
}

fn all_vars(&self) -> Vec<Var> {
let data = self.data.read().unwrap();
#[allow(clippy::map_clone)]
data.values().map(|c| c.clone()).collect::<Vec<_>>()
}

fn insert_var(&self, name: String, var: Var) {
let mut data = self.data.write().unwrap();
data.insert(name, var);
}

fn contains_key(&self, name: &str) -> bool {
let data = self.data.read().unwrap();
data.contains_key(name)
}

fn len(&self) -> usize {
let data = self.data.read().unwrap();
data.len()
}

fn iter_for_save(&self) -> Vec<(String, Var)> {
let data = self.data.read().unwrap();
data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}

fn iter_for_load(&self) -> Vec<(String, Var)> {
let data = self.data.read().unwrap();
data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}

fn iter_mut_for_load(&self) -> Vec<(String, Var)> {
let data = self.data.read().unwrap();
data.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
}
}

// Generic VarMap implementation
#[derive(Clone)]
pub struct VarMapGeneric<Storage: VarStorage> {
storage: Storage,
}
// Type aliases for easy usage
/// A `VarMap` is a store that holds named variables. Variables can be retrieved from the stores
/// and new variables can be added by providing some initialization config in case they are
/// missing.
/// `VarMap` structures can be serialized in the safetensors format.
#[derive(Clone)]
pub struct VarMap {
data: Arc<Mutex<HashMap<String, Var>>>,
}
pub type VarMap = VarMapGeneric<MutexStorage>; // Original (for training)

impl VarMap {
/// Concurrent version of VarMap using RwLock for better read performance in inference scenarios
pub type ConcurrentVarMap = VarMapGeneric<RwLockStorage>;
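
Because both aliases instantiate the same `VarMapGeneric`, existing call sites compile unchanged; only the alias picks the locking strategy. A small sketch of the intended split:

```rust
use candle_nn::var_map::{ConcurrentVarMap, VarMap};

fn make_maps() -> (VarMap, ConcurrentVarMap) {
    // Training path: Mutex-backed, identical to the pre-existing VarMap API.
    let train_vars = VarMap::new();
    // Inference path: RwLock-backed, so readers do not serialize on a Mutex.
    let infer_vars = ConcurrentVarMap::new();
    (train_vars, infer_vars)
}
```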

impl<Storage: VarStorage> VarMapGeneric<Storage> {
/// Create a new empty `VarMap`.
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let data = Arc::new(Mutex::new(HashMap::new()));
Self { data }
Self {
storage: Storage::new(),
}
}

/// Retrieve all the variables currently stored in the map.
pub fn all_vars(&self) -> Vec<Var> {
let tensor_data = self.data.lock().unwrap();
#[allow(clippy::map_clone)]
tensor_data.values().map(|c| c.clone()).collect::<Vec<_>>()
self.storage.all_vars()
}

/// Save the map in the safetensors format.
pub fn save<P: AsRef<std::path::Path>>(&self, path: P) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
let data = tensor_data.iter().map(|(k, v)| (k, v.as_tensor()));
let data = self.storage.iter_for_save();
let data = data.iter().map(|(k, v)| (k, v.as_tensor()));
safetensors::tensor::serialize_to_file(data, &None, path.as_ref())?;
Ok(())
}
@@ -43,25 +172,25 @@ impl VarMap {
pub fn load<P: AsRef<std::path::Path>>(&mut self, path: P) -> Result<()> {
let path = path.as_ref();
let data = unsafe { candle::safetensors::MmapedSafetensors::new(path)? };
let mut tensor_data = self.data.lock().unwrap();
for (name, var) in tensor_data.iter_mut() {
let data = data.load(name, var.device())?;
if let Err(err) = var.set(&data) {
candle::bail!("error setting {name} using data from {path:?}: {err}",)
let vars = self.storage.iter_mut_for_load();

for (name, var) in vars {
let tensor_data = data.load(&name, var.device())?;
if let Err(err) = var.set(&tensor_data) {
candle::bail!("error setting {name} using data from {path:?}: {err}")
}
}
Ok(())
}
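
For reference, a save/load round trip through these two methods might look like the following sketch (the file name and variable name are illustrative, and `init::ZERO` is assumed from candle-nn's init module):

```rust
use candle::{DType, Device};
use candle_nn::var_map::VarMap;

fn checkpoint_roundtrip() -> candle::Result<()> {
    let dev = Device::Cpu;
    let mut varmap = VarMap::new();
    // Materialize one variable so there is something to serialize.
    varmap.get((2, 2), "w", candle_nn::init::ZERO, DType::F32, &dev)?;
    varmap.save("checkpoint.safetensors")?;
    // `load` overwrites every Var already present in the map, in place.
    varmap.load("checkpoint.safetensors")?;
    Ok(())
}
```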

/// Set a named variable to some value.
pub fn set_one<K: AsRef<str>, V: AsRef<Tensor>>(&mut self, name: K, value: V) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
let name = name.as_ref();
match tensor_data.get(name) {
match self.storage.get_var(name) {
None => candle::bail!("cannot find {name} in VarMap"),
Some(var) => {
if let Err(err) = var.set(value.as_ref()) {
candle::bail!("error setting {name}: {err}",)
candle::bail!("error setting {name}: {err}")
}
}
}
@@ -76,14 +205,13 @@ impl VarMap {
&mut self,
iter: I,
) -> Result<()> {
let tensor_data = self.data.lock().unwrap();
for (name, value) in iter {
let name = name.as_ref();
match tensor_data.get(name) {
match self.storage.get_var(name) {
None => candle::bail!("cannot find {name} in VarMap"),
Some(var) => {
if let Err(err) = var.set(value.as_ref()) {
candle::bail!("error setting {name}: {err}",)
candle::bail!("error setting {name}: {err}")
}
}
}
@@ -101,21 +229,72 @@ impl VarMap {
device: &Device,
) -> Result<Tensor> {
let shape = shape.into();
let mut tensor_data = self.data.lock().unwrap();
if let Some(tensor) = tensor_data.get(path) {
let tensor_shape = tensor.shape();
if let Some(existing_var) = self.storage.get_var(path) {
let tensor_shape = existing_var.shape();
if &shape != tensor_shape {
candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}")
}
return Ok(tensor.as_tensor().clone());
return Ok(existing_var.as_tensor().clone());
}
let var = init.var(shape, dtype, device)?;
let tensor = var.as_tensor().clone();
tensor_data.insert(path.to_string(), var);
self.storage.insert_var(path.to_string(), var);
Ok(tensor)
}

pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
&self.data
/// Get a variable by name (method for compatibility).
pub fn get_var(&self, name: &str) -> Option<Var> {
self.storage.get_var(name)
}

/// Insert a new variable (method for compatibility).
pub fn insert(&self, name: String, var: Var) {
self.storage.insert_var(name, var);
}

/// Check if a variable exists (method for compatibility).
pub fn contains_key(&self, name: &str) -> bool {
self.storage.contains_key(name)
}

/// Convert this map into an RwLock-backed `ConcurrentVarMap`, transferring all variables (for migration).
pub fn into_concurrent(self) -> ConcurrentVarMap {
let concurrent = ConcurrentVarMap::new();

// Transfer all variables
for (name, var) in self.storage.iter_for_save() {
concurrent.insert(name, var);
}

concurrent
}
}
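
A sketch of the migration path this enables, going from a map populated during training to the concurrent variant for serving:

```rust
use candle_nn::var_map::{ConcurrentVarMap, VarMap};

fn to_inference(trained: VarMap) -> ConcurrentVarMap {
    // into_concurrent copies every (name, Var) pair into a fresh
    // RwLock-backed map; Vars are cheap Arc-style handles, so the tensors
    // themselves are shared rather than duplicated.
    trained.into_concurrent()
}
```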

impl VarMap {
pub fn data(&self) -> &Arc<Mutex<HashMap<String, Var>>> {
&self.storage.data
}
}

impl ConcurrentVarMap {
pub fn read_data(&self) -> std::sync::RwLockReadGuard<HashMap<String, Var>> {
self.storage.data.read().unwrap()
}
pub fn write_data(&self) -> std::sync::RwLockWriteGuard<HashMap<String, Var>> {
self.storage.data.write().unwrap()
}

pub fn get_vars_batch(&self, names: &[&str]) -> HashMap<String, Var> {
let data = self.storage.data.read().unwrap();
names
.iter()
.filter_map(|&name| data.get(name).map(|v| (name.to_string(), v.clone())))
.collect()
}

pub fn data(&self) -> &Arc<RwLock<HashMap<String, Var>>> {
&self.storage.data
}
}
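
The accessors above expose the RwLock directly, and `get_vars_batch` takes the read lock once for a whole set of names rather than once per lookup. A minimal sketch of concurrent readers (the variable names are illustrative and assumed to be present):

```rust
use candle_nn::var_map::ConcurrentVarMap;
use std::thread;

fn read_from_workers(varmap: &ConcurrentVarMap) {
    thread::scope(|s| {
        for _ in 0..4 {
            s.spawn(|| {
                // Read locks are shared: all four workers can hold one at the
                // same time, where a Mutex-backed map would serialize them.
                let vars = varmap.get_vars_batch(&["w1", "w2"]);
                for (name, var) in vars {
                    // Illustrative read-only access.
                    let _ = (name, var.shape());
                }
            });
        }
    });
}
```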