81 changes: 29 additions & 52 deletions mistralrs-core/src/device_map.rs
@@ -10,27 +10,6 @@ use mistralrs_quant::ShardedVarBuilder;
 use serde::Deserialize;
 use tracing::info;
 
-fn split_range(range: std::ops::Range<usize>, n: usize) -> Vec<std::ops::Range<usize>> {
-    assert!(n > 0, "n must be non-zero");
-
-    let total = range.end - range.start;
-    let chunk_size = total / n;
-    let remainder = total % n;
-
-    let mut chunks = Vec::with_capacity(n);
-    let mut start = range.start;
-
-    // Create each chunk. The first `remainder` chunks get an extra element.
-    for i in 0..n {
-        let extra = if i < remainder { 1 } else { 0 };
-        let end = start + chunk_size + extra;
-        chunks.push(start..end);
-        start = end;
-    }
-
-    chunks
-}
-
 #[derive(Debug, Default, Deserialize, Clone)]
 pub struct DeviceLayerMapMetadata {
     pub ordinal: usize,
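
For reference, the deleted `split_range` helper divided a layer range into `n` contiguous chunks, handing the first `total % n` chunks one extra element; with the pipeline-parallel mapper gone, it has no remaining callers. A worked example of the arithmetic (values are illustrative):

    // 26 layers over 3 devices: 26 / 3 = 8 with remainder 2,
    // so the first two chunks get 9 layers each.
    let chunks = split_range(0..26, 3);
    assert_eq!(chunks, vec![0..9, 9..18, 18..26]);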
@@ -44,11 +23,11 @@ pub enum DeviceMapSetting {
     /// Automatic device mapping (recommended).
     Auto(AutoDeviceMapParams),
     /// Dummy device mapping for a NCCL pipeline
-    Nccl { devices: Vec<Device> },
-    /// Device mapping when using PP (agnostic of TP)
-    NcclPipelineParallel {
-        devices_and_comms: Vec<(Arc<mistralrs_quant::Comm>, Device)>,
+    DummyNccl { nm_device: Device },
+    /// Real device mapping for a NCCL pipeline
+    Nccl {
         nm_device: Device,
+        comm: Arc<mistralrs_quant::Comm>,
     },
 }
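
The two NCCL cases are now explicit: `DummyNccl` is the placeholder used before a communicator exists, while `Nccl` carries the real `Comm` next to the non-mapped device. A hedged construction sketch, assuming a `device: Device` and a `comm: mistralrs_quant::Comm` are in scope:

    // Placeholder mapping: no communicator yet.
    let dummy = DeviceMapSetting::DummyNccl { nm_device: device.clone() };
    // Real mapping: one communicator shared by all layers.
    let real = DeviceMapSetting::Nccl {
        nm_device: device.clone(),
        comm: Arc::new(comm),
    };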

@@ -87,27 +66,21 @@ impl DeviceMapSetting {
         topology: Option<&Topology>,
     ) -> Result<Box<dyn DeviceMapper + Send + Sync>> {
         match self {
-            Self::Nccl { devices } => {
+            Self::Nccl { nm_device, comm } => {
                 once_log_info("Loading model using a NCCL-parallelized pipeline.");
                 Ok(Box::new(NcclDeviceMapper {
-                    nm_device: devices[0].clone(),
-                    devices: devices.clone(),
+                    nm_device: nm_device.clone(),
                     model_layers,
+                    comm: Some(comm.clone()),
                 }))
             }
-            Self::NcclPipelineParallel {
-                devices_and_comms,
-                nm_device,
-            } => {
-                let splits = split_range(0..model_layers, devices_and_comms.len());
-
-                let mut mappings = Vec::new();
-                for (split, device) in splits.into_iter().zip(devices_and_comms) {
-                    mappings.extend(vec![device.clone(); split.len()]);
-                }
-
-                Ok(Box::new(NcclPipelineParallelMapper {
-                    mappings,
+            Self::DummyNccl { nm_device } => {
+                once_log_info("Loading model using a NCCL-parallelized pipeline.");
+                Ok(Box::new(NcclDeviceMapper {
+                    nm_device: nm_device.clone(),
+                    model_layers,
+                    comm: None,
                 }))
             }
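
Both arms now build the same `NcclDeviceMapper`; the variant only decides the `comm` field. A sketch of that reduction (illustrative, assuming a `setting: &DeviceMapSetting` that is one of the NCCL variants):

    let comm = match setting {
        DeviceMapSetting::Nccl { comm, .. } => Some(comm.clone()),
        DeviceMapSetting::DummyNccl { .. } => None,
        _ => unreachable!("only the NCCL variants reach this path"),
    };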

@@ -422,7 +395,8 @@ impl DeviceMapper for DummyDeviceMapper {
 #[derive(Debug)]
 pub struct NcclDeviceMapper {
     nm_device: Device,
-    devices: Vec<Device>,
     model_layers: usize,
+    comm: Option<Arc<mistralrs_quant::Comm>>,
 }
 
 impl DeviceMapper for NcclDeviceMapper {
@@ -445,7 +419,7 @@ impl DeviceMapper for NcclDeviceMapper {
         Some(&self.nm_device)
     }
     fn get_unique_devices(&self) -> Vec<Device> {
-        self.devices.clone()
+        vec![self.nm_device.clone()]
     }
     fn cast_nm_device(&self, x: &Tensor, loading_isq: bool) -> Result<Tensor> {
         if loading_isq {
@@ -467,21 +441,24 @@ impl DeviceMapper for NcclDeviceMapper {
     }
     fn get_min_dtype(&self, dtype: &dyn TryIntoDType) -> Result<DType> {
         dtype
-            .try_into_dtype(&self.devices.iter().collect::<Vec<_>>())
+            .try_into_dtype(&[&self.nm_device])
             .map_err(candle_core::Error::msg)
     }
     fn num_device_mapping_layers(&self) -> usize {
-        // Effectively one layer
-        1
+        self.model_layers
     }
     fn get_comm_for(&self, layer_idx: usize) -> Result<Arc<mistralrs_quant::Comm>> {
-        let id = mistralrs_quant::Id::new();
-        Ok(Arc::new(mistralrs_quant::Comm::from_device(
-            id,
-            self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
-            0,
-            1,
-        )?))
+        if let Some(comm) = &self.comm {
+            Ok(comm.clone())
+        } else {
+            let id = mistralrs_quant::Id::new();
+            Ok(Arc::new(mistralrs_quant::Comm::from_device(
+                id,
+                self.device_for(layer_idx, false).unwrap_or(&self.nm_device),
+                0,
+                1,
+            )?))
+        }
     }
 }
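
A hedged usage sketch of the new `get_comm_for` behavior (`mapper` is an `NcclDeviceMapper`; the layer index is arbitrary):

    // With comm: Some(..) (the Nccl variant) every layer shares the one
    // stored communicator; with comm: None (DummyNccl) a fresh single-rank
    // communicator (rank 0, world size 1) is created per call.
    let layer_comm = mapper.get_comm_for(12)?;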

10 changes: 5 additions & 5 deletions mistralrs-core/src/pipeline/normal.rs
@@ -311,8 +311,8 @@ impl Loader for NormalLoader {
 
         // If auto, convert to Map if not using nccl
         if use_nccl {
-            mapper = DeviceMapSetting::Nccl {
-                devices: available_devices.clone(),
+            mapper = DeviceMapSetting::DummyNccl {
+                nm_device: available_devices[0].clone(),
             };
         } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
             // Initial dtype
@@ -626,9 +626,9 @@ impl Loader for NormalLoader {
             info!("Loading all ranks.");
             let device = available_devices[0].clone();
             // The mapper is specific to this pipeline
-            let mapper = DeviceMapSetting::NcclPipelineParallel {
-                devices_and_comms: vec![(Arc::new(comm), device.clone())],
-                nm_device: device.clone(),
+            let mapper = DeviceMapSetting::Nccl {
+                nm_device: available_devices[0].clone(),
+                comm: Arc::new(comm),
             }
             .into_mapper(
                 self.inner.get_total_device_mapping_num_layers(&config)?,
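
Taken together, the loader now switches settings in two phases; a condensed, hypothetical sketch of the flow from the two hunks above:

    // Phase 1, before any communicator exists: placeholder mapping
    // pinned to the primary device.
    mapper = DeviceMapSetting::DummyNccl {
        nm_device: available_devices[0].clone(),
    };
    // Phase 2, once `comm` has been created for this pipeline: the
    // real NCCL mapping that every layer's comm is cloned from.
    let mapper = DeviceMapSetting::Nccl {
        nm_device: available_devices[0].clone(),
        comm: Arc::new(comm),
    };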