Skip to content

Commit

Permalink
improve image tensor conversions (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarriba authored Sep 28, 2024
1 parent a4b5fd4 commit 563ee3e
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 48 deletions.
94 changes: 48 additions & 46 deletions crates/kornia-core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,54 @@ where
}
}

/// Creates a new `TensorStorage` from a slice of data.
///
/// # Arguments
///
/// * `data` - A slice containing the data to be stored.
/// * `alloc` - The allocator to use for creating the storage.
///
/// # Returns
///
/// A new `TensorStorage` instance containing a copy of the input data.
///
/// # Errors
///
/// Returns a `TensorAllocatorError` if the allocation fails.
pub fn from_slice(data: &[T], alloc: A) -> Self {
let buffer = Buffer::from_slice_ref(data);
Self {
data: buffer.into(),
alloc: Arc::new(alloc),
}
}

/// Creates a new tensor storage from an existing raw pointer with the given allocator.
///
/// # Arguments
///
/// * `ptr` - The existing raw pointer to the tensor data.
/// * `len` - The number of elements in the tensor storage.
/// * `alloc` - A reference to the allocator used to allocate the tensor storage.
///
/// # Safety
///
/// The pointer must be properly aligned and have the correct length.
pub unsafe fn from_ptr(ptr: *mut T, len: usize, alloc: &A) -> Self {
// create the buffer
let buffer = Buffer::from_custom_allocation(
NonNull::new_unchecked(ptr as *mut u8),
len * std::mem::size_of::<T>(),
Arc::new(()),
);

// create tensor storage
Self {
data: buffer.into(),
alloc: Arc::new(alloc.clone()),
}
}

/// Returns the allocator used to allocate the tensor storage.
#[inline]
pub fn alloc(&self) -> &A {
Expand Down Expand Up @@ -213,52 +261,6 @@ where
pub fn get_unchecked(&self, index: usize) -> &T {
unsafe { self.data.get_unchecked(index) }
}

/// Creates a new `TensorStorage` from a slice of data.
///
/// # Arguments
///
/// * `data` - A slice containing the data to be stored.
/// * `alloc` - The allocator to use for creating the storage.
///
/// # Returns
///
/// A new `TensorStorage` instance containing a copy of the input data.
///
/// # Errors
///
/// Returns a `TensorAllocatorError` if the allocation fails.
pub fn from_slice(data: &[T], alloc: A) -> Result<Self, TensorAllocatorError> {
let mut storage = Self::new(data.len(), alloc)?;
storage.as_mut_slice().copy_from_slice(data);
Ok(storage)
}

/// Creates a new tensor storage from an existing raw pointer with the given allocator.
///
/// # Arguments
///
/// * `ptr` - The existing raw pointer to the tensor data.
/// * `len` - The number of elements in the tensor storage.
/// * `alloc` - A reference to the allocator used to allocate the tensor storage.
///
/// # Safety
///
/// The pointer must be properly aligned and have the correct length.
pub unsafe fn from_ptr(ptr: *mut T, len: usize, alloc: &A) -> Self {
// create the buffer
let buffer = Buffer::from_custom_allocation(
NonNull::new_unchecked(ptr as *mut u8),
len * std::mem::size_of::<T>(),
Arc::new(()),
);

// create tensor storage
Self {
data: buffer.into(),
alloc: Arc::new(alloc.clone()),
}
}
}

/// A new `TensorStorage` instance with cloned data if successful, otherwise an error.
Expand Down
2 changes: 1 addition & 1 deletion crates/kornia-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ where
if numel != data.len() {
return Err(TensorError::InvalidShape(numel));
}
let storage = TensorStorage::from_slice(data, alloc)?;
let storage = TensorStorage::from_slice(data, alloc);
let strides = get_strides_from_shape(shape);
Ok(Self {
storage,
Expand Down
107 changes: 106 additions & 1 deletion crates/kornia-image/src/image.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use kornia_core::{CpuAllocator, SafeTensorType, Tensor3};
use kornia_core::{CpuAllocator, SafeTensorType, Tensor, Tensor2, Tensor3};

use crate::error::ImageError;

Expand Down Expand Up @@ -177,6 +177,27 @@ where
Ok(image)
}

/// Create a new image from a slice of pixel data.
///
/// # Arguments
///
/// * `size` - The size of the image in pixels.
/// * `data` - A slice containing the pixel data.
///
/// # Returns
///
/// A new image created from the given size and pixel data.
///
/// # Errors
///
/// Returns an error if the length of the data slice doesn't match the image dimensions,
/// or if there's an issue creating the tensor or image.
pub fn from_size_slice(size: ImageSize, data: &[T]) -> Result<Self, ImageError> {
let tensor: Tensor3<T> =
Tensor::from_shape_slice([size.height, size.width, C], data, CpuAllocator)?;
Image::try_from(tensor)
}

/// Cast the pixel data of the image to a different type.
///
/// # Returns
Expand Down Expand Up @@ -414,9 +435,54 @@ where
}
}

/// helper to convert an single channel tensor to a kornia image with try into
impl<T> TryFrom<Tensor2<T>> for Image<T, 1>
where
T: SafeTensorType,
{
type Error = ImageError;

fn try_from(value: Tensor2<T>) -> Result<Self, Self::Error> {
Self::from_size_slice(
ImageSize {
width: value.shape[1],
height: value.shape[0],
},
value.as_slice(),
)
}
}

/// helper to convert an multi channel tensor to a kornia image with try into
impl<T, const C: usize> TryFrom<Tensor3<T>> for Image<T, C>
where
T: SafeTensorType,
{
type Error = ImageError;

fn try_from(value: Tensor3<T>) -> Result<Self, Self::Error> {
if value.shape[2] != C {
return Err(ImageError::InvalidChannelShape(value.shape[2], C));
}
Ok(Self(value))
}
}

impl<T, const C: usize> TryInto<Tensor3<T>> for Image<T, C>
where
T: SafeTensorType,
{
type Error = ImageError;

fn try_into(self) -> Result<Tensor3<T>, Self::Error> {
Ok(self.0)
}
}

#[cfg(test)]
mod tests {
use crate::image::{Image, ImageError, ImageSize};
use kornia_core::{CpuAllocator, Tensor};

#[test]
fn image_size() {
Expand Down Expand Up @@ -560,4 +626,43 @@ mod tests {

Ok(())
}

#[test]
fn image_from_tensor() -> Result<(), ImageError> {
let tensor =
Tensor::<u8, 2, CpuAllocator>::from_shape_vec([2, 3], vec![0u8; 2 * 3], CpuAllocator)?;

let image = Image::<u8, 1>::try_from(tensor.clone())?;
assert_eq!(image.size().width, 3);
assert_eq!(image.size().height, 2);
assert_eq!(image.num_channels(), 1);

let image_2: Image<u8, 1> = tensor.try_into()?;
assert_eq!(image_2.size().width, 3);
assert_eq!(image_2.size().height, 2);
assert_eq!(image_2.num_channels(), 1);

Ok(())
}

#[test]
fn image_from_tensor_3d() -> Result<(), ImageError> {
let tensor = Tensor::<u8, 3, CpuAllocator>::from_shape_vec(
[2, 3, 4],
vec![0u8; 2 * 3 * 4],
CpuAllocator,
)?;

let image = Image::<u8, 4>::try_from(tensor.clone())?;
assert_eq!(image.size().width, 3);
assert_eq!(image.size().height, 2);
assert_eq!(image.num_channels(), 4);

let image_2: Image<u8, 4> = tensor.try_into()?;
assert_eq!(image_2.size().width, 3);
assert_eq!(image_2.size().height, 2);
assert_eq!(image_2.num_channels(), 4);

Ok(())
}
}

0 comments on commit 563ee3e

Please sign in to comment.