Skip to content

Commit

Permalink
Delete ClassicTensor and cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jleibs committed Feb 28, 2023
1 parent 3acfcf3 commit 31d2e5b
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 599 deletions.
4 changes: 3 additions & 1 deletion crates/re_log_types/src/component_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ pub use radius::Radius;
pub use rect::Rect2D;
pub use scalar::{Scalar, ScalarPlotProps};
pub use size::Size3D;
pub use tensor::{Tensor, TensorData, TensorDataMeaning, TensorDimension, TensorId, TensorTrait};
pub use tensor::{
Tensor, TensorCastError, TensorData, TensorDataMeaning, TensorDimension, TensorId, TensorTrait,
};
pub use text_entry::TextEntry;
pub use transform::{Pinhole, Rigid3, Transform};
pub use vec::{Vec2D, Vec3D, Vec4D};
Expand Down
104 changes: 36 additions & 68 deletions crates/re_log_types/src/component_types/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::sync::Arc;

use arrow2::array::{FixedSizeBinaryArray, MutableFixedSizeBinaryArray};
use arrow2::buffer::Buffer;
use arrow2_convert::deserialize::ArrowDeserialize;
use arrow2_convert::field::ArrowField;
use arrow2_convert::{serialize::ArrowSerialize, ArrowDeserialize, ArrowField, ArrowSerialize};

use crate::{msg_bundle::Component, ClassicTensor, TensorDataStore};
use crate::msg_bundle::Component;
use crate::{TensorDataType, TensorElement};

pub trait TensorTrait {
Expand Down Expand Up @@ -415,7 +413,7 @@ impl Component for Tensor {
}
}

#[derive(thiserror::Error, Debug)]
#[derive(thiserror::Error, Debug, PartialEq)]
pub enum TensorCastError {
#[error("ndarray type mismatch with tensor storage")]
TypeMismatch,
Expand All @@ -435,65 +433,6 @@ pub enum TensorCastError {
F16NotSupported,
}

impl From<&Tensor> for ClassicTensor {
fn from(value: &Tensor) -> Self {
let (dtype, data) = match &value.data {
TensorData::U8(data) => (
crate::TensorDataType::U8,
TensorDataStore::Dense(Arc::from(data.as_slice())),
),
TensorData::U16(data) => (
crate::TensorDataType::U16,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::U32(data) => (
crate::TensorDataType::U32,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::U64(data) => (
crate::TensorDataType::U64,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::I8(data) => (
crate::TensorDataType::I8,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::I16(data) => (
crate::TensorDataType::I16,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::I32(data) => (
crate::TensorDataType::I32,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::I64(data) => (
crate::TensorDataType::I64,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::F32(data) => (
crate::TensorDataType::F32,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::F64(data) => (
crate::TensorDataType::F64,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::JPEG(data) => (
crate::TensorDataType::U8,
TensorDataStore::Jpeg(Arc::from(data.as_slice())),
),
};

ClassicTensor::new(
value.tensor_id,
value.shape.clone(),
dtype,
value.meaning,
data,
)
}
}

macro_rules! tensor_type {
($type:ty, $variant:ident) => {
impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, $type> {
Expand Down Expand Up @@ -523,15 +462,23 @@ macro_rules! tensor_type {
name: None,
})
.collect();
view.to_slice()
.ok_or(TensorCastError::NotContiguousStdOrder)
.map(|slice| Tensor {

match view.to_slice() {
Some(slice) => Ok(Tensor {
tensor_id: TensorId::random(),
shape,
data: TensorData::$variant(Vec::from(slice).into()),
meaning: TensorDataMeaning::Unknown,
meter: None,
})
}),
None => Ok(Tensor {
tensor_id: TensorId::random(),
shape,
data: TensorData::$variant(view.iter().cloned().collect::<Vec<_>>().into()),
meaning: TensorDataMeaning::Unknown,
meter: None,
}),
}
}
}

Expand Down Expand Up @@ -602,6 +549,24 @@ pub enum ImageError {
ReadError(#[from] std::io::Error),
}

impl Tensor {
pub fn new(
tensor_id: TensorId,
shape: Vec<TensorDimension>,
data: TensorData,
meaning: TensorDataMeaning,
meter: Option<f32>,
) -> Self {
Self {
tensor_id,
shape,
data,
meaning,
meter,
}
}
}

#[cfg(feature = "image")]
impl Tensor {
/// Construct a tensor from the contents of a JPEG file on disk.
Expand Down Expand Up @@ -708,6 +673,7 @@ impl Tensor {
#[test]
fn test_ndarray() {
let t0 = Tensor {
tensor_id: TensorId::random(),
shape: vec![
TensorDimension {
size: 2,
Expand All @@ -718,7 +684,9 @@ fn test_ndarray() {
name: None,
},
],
data: TensorData::U16(vec![1, 2, 3, 4]),
data: TensorData::U16(vec![1, 2, 3, 4].into()),
meaning: TensorDataMeaning::Unknown,
meter: None,
};
let a0: ndarray::ArrayViewD<'_, u16> = (&t0).try_into().unwrap();
dbg!(a0); // NOLINT
Expand Down
Loading

0 comments on commit 31d2e5b

Please sign in to comment.