Skip to content

Commit

Permalink
half:add auto dtype for to_vec0..3
Browse files Browse the repository at this point in the history
  • Loading branch information
haricot committed Jan 23, 2025
1 parent b484829 commit ec25c81
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 16 deletions.
7 changes: 7 additions & 0 deletions candle-core/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ impl DType {
Self::BF16 | Self::F16 | Self::F32 | Self::F64 => true,
}
}

pub fn is_half_to<S: crate::WithDType>(&self) -> bool {
match (self, S::DTYPE) {
(DType::BF16 | DType::F16, DType::F32 | DType::U32 | DType::U8) => true,
_ => false,
}
}
}

pub trait WithDType:
Expand Down
51 changes: 35 additions & 16 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,10 @@ impl Tensor {

/// An alias for `to_scalar`.
pub fn to_vec0<S: crate::WithDType>(&self) -> Result<S> {
self.to_scalar::<S>()
match self.dtype().is_half_to::<S>() {
true => self.to_dtype(S::DTYPE)?.to_scalar::<S>(),
_ => self.to_scalar::<S>(),
}
}

/// Repeat this tensor along the specified dimensions.
Expand Down Expand Up @@ -1604,23 +1607,29 @@ impl Tensor {

/// Returns the data contained in a 1D tensor as a vector of scalar values.
pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
if self.rank() != 1 {
Err(Error::UnexpectedNumberOfDims {
let tensor = match self.dtype().is_half_to::<S>() {
true => &self.to_dtype(S::DTYPE)?,
false => &self,
};

if tensor.rank() != 1 {
return Err(Error::UnexpectedNumberOfDims {
expected: 1,
got: self.rank(),
shape: self.shape().clone(),
got: tensor.rank(),
shape: tensor.shape().clone(),
}
.bt())?
.bt());
}
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
let data = match self.layout.contiguous_offsets() {
let data = match tensor.layout().contiguous_offsets() {
Some((o1, o2)) => data[o1..o2].to_vec(),
None => self.strided_index().map(|i| data[i]).collect(),
None => tensor.strided_index().map(|i| data[i]).collect(),
};
Ok::<Vec<_>, Error>(data)
};
match &*self.storage() {
let storage = tensor.storage();
match &*storage {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Expand All @@ -1629,19 +1638,23 @@ impl Tensor {

/// Returns the data contained in a 2D tensor as a vector of vector of scalar values.
pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
let (dim1, dim2) = self.dims2()?;
let tensor = match self.dtype().is_half_to::<S>() {
true => &self.to_dtype(S::DTYPE)?,
false => &self,
};
let (dim1, dim2) = tensor.dims2()?;
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
let mut rows = vec![];
match self.layout.contiguous_offsets() {
match tensor.layout.contiguous_offsets() {
Some((o1, o2)) => {
let data = &data[o1..o2];
for idx_row in 0..dim1 {
rows.push(data[idx_row * dim2..(idx_row + 1) * dim2].to_vec())
}
}
None => {
let mut src_index = self.strided_index();
let mut src_index = tensor.strided_index();
for _idx_row in 0..dim1 {
let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
rows.push(row)
Expand All @@ -1651,7 +1664,8 @@ impl Tensor {
}
Ok(rows)
};
match &*self.storage() {
let storage = tensor.storage();
match &*storage {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Expand All @@ -1660,7 +1674,11 @@ impl Tensor {

/// Returns the data contained in a 3D tensor.
pub fn to_vec3<S: crate::WithDType>(&self) -> Result<Vec<Vec<Vec<S>>>> {
let (dim1, dim2, dim3) = self.dims3()?;
let tensor = match self.dtype().is_half_to::<S>() {
true => &self.to_dtype(S::DTYPE)?,
false => &self,
};
let (dim1, dim2, dim3) = tensor.dims3()?;
let from_cpu_storage = |cpu_storage: &crate::CpuStorage| {
let data = S::cpu_storage_as_slice(cpu_storage)?;
let mut top_rows = vec![];
Expand All @@ -1678,7 +1696,7 @@ impl Tensor {
}
}
None => {
let mut src_index = self.strided_index();
let mut src_index = tensor.strided_index();
for _idx in 0..dim1 {
let mut rows = vec![];
for _jdx in 0..dim2 {
Expand All @@ -1692,7 +1710,8 @@ impl Tensor {
}
Ok(top_rows)
};
match &*self.storage() {
let storage = tensor.storage();
match &*storage {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Expand Down

0 comments on commit ec25c81

Please sign in to comment.