Convert view_tensor to use the new native Tensors (#1439)
* Convert view_tensor to use the new native Tensors

* Limit F16 errors to the tensor module
jleibs authored Feb 28, 2023
1 parent 6be85c0 commit 5661806
Showing 8 changed files with 103 additions and 89 deletions.
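
The heart of the change, for orientation: viewer code that used to convert the arrow Tensor component into a ClassicTensor and read it through re_tensor_ops::as_ndarray now views the native Tensor directly, via the TryFrom<&Tensor> impls generated by tensor_type! in tensor.rs below. A minimal sketch of the new call pattern, assuming a u8 tensor and an illustrative helper name:

use ndarray::ArrayViewD;
use re_log_types::component_types::{Tensor, TensorCastError};

// Illustrative helper (not part of this commit): borrow a native Tensor's data
// as an ndarray view and compute something from it.
fn max_u8(tensor: &Tensor) -> Result<Option<u8>, TensorCastError> {
    // tensor_type! provides TryFrom<&Tensor> for every element type except f16,
    // which returns TensorCastError::F16NotSupported for now (see below).
    let view: ArrayViewD<'_, u8> = ArrayViewD::try_from(tensor)?;
    Ok(view.iter().copied().max())
}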
32 changes: 31 additions & 1 deletion crates/re_log_types/src/component_types/tensor.rs
@@ -6,8 +6,8 @@ use arrow2_convert::deserialize::ArrowDeserialize;
use arrow2_convert::field::ArrowField;
use arrow2_convert::{serialize::ArrowSerialize, ArrowDeserialize, ArrowField, ArrowSerialize};

use crate::TensorElement;
use crate::{msg_bundle::Component, ClassicTensor, TensorDataStore};
use crate::{TensorDataType, TensorElement};

pub trait TensorTrait {
fn id(&self) -> TensorId;
@@ -17,6 +17,7 @@ pub trait TensorTrait {
fn is_vector(&self) -> bool;
fn meaning(&self) -> TensorDataMeaning;
fn get(&self, index: &[u64]) -> Option<TensorElement>;
fn dtype(&self) -> TensorDataType;
}

// ----------------------------------------------------------------------------
@@ -390,6 +391,21 @@ impl TensorTrait for Tensor {
TensorData::JPEG(_) => None, // Too expensive to unpack here.
}
}

fn dtype(&self) -> TensorDataType {
match &self.data {
TensorData::U8(_) | TensorData::JPEG(_) => TensorDataType::U8,
TensorData::U16(_) => TensorDataType::U16,
TensorData::U32(_) => TensorDataType::U32,
TensorData::U64(_) => TensorDataType::U64,
TensorData::I8(_) => TensorDataType::I8,
TensorData::I16(_) => TensorDataType::I16,
TensorData::I32(_) => TensorDataType::I32,
TensorData::I64(_) => TensorDataType::I64,
TensorData::F32(_) => TensorDataType::F32,
TensorData::F64(_) => TensorDataType::F64,
}
}
}

impl Component for Tensor {
@@ -412,6 +428,11 @@ pub enum TensorCastError {

#[error("ndarray Array is not contiguous and in standard order")]
NotContiguousStdOrder,

#[error(
"tensors do not currently support f16 data (https://github.com/rerun-io/rerun/issues/854)"
)]
F16NotSupported,
}

impl From<&Tensor> for ClassicTensor {
@@ -554,6 +575,15 @@ tensor_type!(i64, I64);
tensor_type!(f32, F32);
tensor_type!(f64, F64);

// TODO(#854) Switch back to `tensor_type!` once we have F16 tensors
impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, half::f16> {
type Error = TensorCastError;

fn try_from(_: &'a Tensor) -> Result<Self, Self::Error> {
Err(TensorCastError::F16NotSupported)
}
}

// ----------------------------------------------------------------------------

#[cfg(feature = "image")]
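
A hedged usage sketch for the two additions above, the new dtype() trait method and the f16 TryFrom stub. The helper name is illustrative, and the import paths are assumed from how the rest of this diff imports these types:

use half::f16;
use ndarray::ArrayViewD;
use re_log_types::component_types::{Tensor, TensorCastError, TensorTrait};
use re_log_types::TensorDataType;

// Illustrative helper: cheap dtype lookup plus the current f16 limitation.
fn dtype_and_f16_support(tensor: &Tensor) -> (TensorDataType, bool) {
    // dtype() is derived from the TensorData variant, so it never unpacks JPEG data.
    let dtype = tensor.dtype();

    // The manual TryFrom impl above means an f16 view currently always fails with
    // TensorCastError::F16NotSupported (tracked in rerun-io/rerun#854).
    let f16_supported = !matches!(
        ArrayViewD::<f16>::try_from(tensor),
        Err(TensorCastError::F16NotSupported)
    );

    (dtype, f16_supported)
}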
10 changes: 5 additions & 5 deletions crates/re_log_types/src/data.rs
@@ -365,6 +365,11 @@ impl component_types::TensorTrait for ClassicTensor {
fn get(&self, index: &[u64]) -> Option<TensorElement> {
self.get(index)
}

#[inline]
fn dtype(&self) -> TensorDataType {
self.dtype
}
}

impl ClassicTensor {
@@ -394,11 +399,6 @@ impl ClassicTensor {
self.shape.as_slice()
}

#[inline]
pub fn dtype(&self) -> TensorDataType {
self.dtype
}

#[inline]
pub fn meaning(&self) -> component_types::TensorDataMeaning {
self.meaning
5 changes: 4 additions & 1 deletion crates/re_tensor_ops/src/lib.rs
@@ -4,7 +4,10 @@
//! This is particularly helpful for performing slice-operations for
//! dimensionality reduction.
use re_log_types::{component_types, ClassicTensor, TensorDataStore, TensorDataTypeTrait};
use re_log_types::{
component_types::{self, TensorTrait},
ClassicTensor, TensorDataStore, TensorDataTypeTrait,
};

pub mod dimension_mapping;

40 changes: 19 additions & 21 deletions crates/re_viewer/src/misc/caches/mod.rs
@@ -1,7 +1,7 @@
mod mesh_cache;
mod tensor_image_cache;

use re_log_types::component_types;
use re_log_types::component_types::{self, TensorTrait};
pub use tensor_image_cache::{AsDynamicImage, TensorImageView};

/// Does memoization of different things for the immediate mode UI.
@@ -33,9 +33,9 @@ impl Caches {
tensor_stats.clear();
}

pub fn tensor_stats(&mut self, tensor: &re_log_types::ClassicTensor) -> &TensorStats {
pub fn tensor_stats(&mut self, tensor: &re_log_types::component_types::Tensor) -> &TensorStats {
self.tensor_stats
.entry(tensor.id())
.entry(tensor.tensor_id)
.or_insert_with(|| TensorStats::new(tensor))
}
}
@@ -45,11 +45,10 @@ pub struct TensorStats {
}

impl TensorStats {
fn new(tensor: &re_log_types::ClassicTensor) -> Self {
use re_log_types::TensorDataType;
use re_tensor_ops::as_ndarray;

fn new(tensor: &re_log_types::component_types::Tensor) -> Self {
use half::f16;
use ndarray::ArrayViewD;
use re_log_types::TensorDataType;

macro_rules! declare_tensor_range_int {
($name: ident, $typ: ty) => {
@@ -103,21 +102,20 @@ impl TensorStats {
}

let range = match tensor.dtype() {
TensorDataType::U8 => as_ndarray::<u8>(tensor).ok().map(tensor_range_u8),
TensorDataType::U16 => as_ndarray::<u16>(tensor).ok().map(tensor_range_u16),
TensorDataType::U32 => as_ndarray::<u32>(tensor).ok().map(tensor_range_u32),
TensorDataType::U64 => as_ndarray::<u64>(tensor).ok().map(tensor_range_u64),

TensorDataType::I8 => as_ndarray::<i8>(tensor).ok().map(tensor_range_i8),
TensorDataType::I16 => as_ndarray::<i16>(tensor).ok().map(tensor_range_i16),
TensorDataType::I32 => as_ndarray::<i32>(tensor).ok().map(tensor_range_i32),
TensorDataType::I64 => as_ndarray::<i64>(tensor).ok().map(tensor_range_i64),

TensorDataType::F16 => as_ndarray::<f16>(tensor).ok().map(tensor_range_f16),
TensorDataType::F32 => as_ndarray::<f32>(tensor).ok().map(tensor_range_f32),
TensorDataType::F64 => as_ndarray::<f64>(tensor).ok().map(tensor_range_f64),
TensorDataType::U8 => ArrayViewD::<u8>::try_from(tensor).map(tensor_range_u8),
TensorDataType::U16 => ArrayViewD::<u16>::try_from(tensor).map(tensor_range_u16),
TensorDataType::U32 => ArrayViewD::<u32>::try_from(tensor).map(tensor_range_u32),
TensorDataType::U64 => ArrayViewD::<u64>::try_from(tensor).map(tensor_range_u64),

TensorDataType::I8 => ArrayViewD::<i8>::try_from(tensor).map(tensor_range_i8),
TensorDataType::I16 => ArrayViewD::<i16>::try_from(tensor).map(tensor_range_i16),
TensorDataType::I32 => ArrayViewD::<i32>::try_from(tensor).map(tensor_range_i32),
TensorDataType::I64 => ArrayViewD::<i64>::try_from(tensor).map(tensor_range_i64),
TensorDataType::F16 => ArrayViewD::<f16>::try_from(tensor).map(tensor_range_f16),
TensorDataType::F32 => ArrayViewD::<f32>::try_from(tensor).map(tensor_range_f32),
TensorDataType::F64 => ArrayViewD::<f64>::try_from(tensor).map(tensor_range_f64),
};

Self { range }
Self { range: range.ok() }
}
}
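
A minimal call-site sketch for the updated cache entry point, written as if it lived next to Caches in this file; the helper name is illustrative. Note the range.ok() at the end of TensorStats::new above: a failed cast (today, any f16 tensor) simply yields no range instead of surfacing an error.

// Illustrative helper, assuming the Caches and TensorStats types defined in this file.
fn stats_for<'c>(
    caches: &'c mut Caches,
    tensor: &re_log_types::component_types::Tensor,
) -> &'c TensorStats {
    // Memoized per tensor.tensor_id, so repeated lookups for the same tensor
    // reuse the previously computed stats instead of re-scanning the data.
    caches.tensor_stats(tensor)
}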
2 changes: 1 addition & 1 deletion crates/re_viewer/src/misc/caches/tensor_image_cache.rs
@@ -4,7 +4,7 @@ use egui::{Color32, ColorImage};
use egui_extras::RetainedImage;
use image::DynamicImage;
use re_log_types::{
component_types::{self, ClassId, TensorDataMeaning},
component_types::{self, ClassId, TensorDataMeaning, TensorTrait},
ClassicTensor, MsgId, TensorDataStore, TensorDataType,
};
use re_renderer::{
53 changes: 19 additions & 34 deletions crates/re_viewer/src/ui/data_ui/image.rs
@@ -1,10 +1,7 @@
use egui::Vec2;
use itertools::Itertools as _;

use re_log_types::{
component_types::{ClassId, TensorDataMeaning},
ClassicTensor,
};
use re_log_types::component_types::{ClassId, Tensor, TensorDataMeaning, TensorTrait};

use crate::misc::{
caches::{TensorImageView, TensorStats},
@@ -19,19 +16,7 @@ pub fn format_tensor_shape_single_line(
format!("[{}]", shape.iter().join(", "))
}

impl DataUi for re_log_types::component_types::Tensor {
fn data_ui(
&self,
ctx: &mut ViewerContext<'_>,
ui: &mut egui::Ui,
verbosity: UiVerbosity,
query: &re_arrow_store::LatestAtQuery,
) {
ClassicTensor::from(self).data_ui(ctx, ui, verbosity, query);
}
}

impl DataUi for ClassicTensor {
impl DataUi for Tensor {
fn data_ui(
&self,
ctx: &mut ViewerContext<'_>,
@@ -114,7 +99,7 @@ impl DataUi for ClassicTensor {
pub fn tensor_dtype_and_shape_ui_grid_contents(
re_ui: &re_ui::ReUi,
ui: &mut egui::Ui,
tensor: &ClassicTensor,
tensor: &Tensor,
tensor_stats: Option<&TensorStats>,
) {
re_ui
@@ -158,7 +143,7 @@ pub fn tensor_dtype_and_shape_ui_grid_contents(
pub fn tensor_dtype_and_shape_ui(
re_ui: &re_ui::ReUi,
ui: &mut egui::Ui,
tensor: &ClassicTensor,
tensor: &Tensor,
tensor_stats: Option<&TensorStats>,
) {
egui::Grid::new("tensor_dtype_and_shape_ui")
@@ -452,11 +437,13 @@ fn histogram_ui(ui: &mut egui::Ui, rgb_image: &image::RgbImage) -> egui::Response
#[cfg(not(target_arch = "wasm32"))]
fn image_options(
ui: &mut egui::Ui,
tensor: &re_log_types::ClassicTensor,
tensor: &re_log_types::component_types::Tensor,
dynamic_image: &image::DynamicImage,
) {
// TODO(emilk): support copying images on web

use re_log_types::component_types::TensorData;

#[cfg(not(target_arch = "wasm32"))]
if ui.button("Click to copy image").clicked() {
let rgba = dynamic_image.to_rgba8();
@@ -471,39 +458,37 @@ fn image_options(
// TODO(emilk): support saving images on web
#[cfg(not(target_arch = "wasm32"))]
if ui.button("Save image…").clicked() {
use re_log_types::TensorDataStore;

match &tensor.data {
TensorDataStore::Dense(_) => {
TensorData::JPEG(bytes) => {
if let Some(path) = rfd::FileDialog::new()
.set_file_name("image.png")
.set_file_name("image.jpg")
.save_file()
{
match dynamic_image.save(&path) {
// TODO(emilk): show a popup instead of logging result
match write_binary(&path, bytes.as_slice()) {
Ok(()) => {
re_log::info!("Image saved to {path:?}");
}
Err(err) => {
re_log::error!("Failed saving image to {path:?}: {err}");
re_log::error!(
"Failed saving image to {path:?}: {}",
re_error::format(&err)
);
}
}
}
}
TensorDataStore::Jpeg(bytes) => {
_ => {
if let Some(path) = rfd::FileDialog::new()
.set_file_name("image.jpg")
.set_file_name("image.png")
.save_file()
{
match write_binary(&path, bytes) {
match dynamic_image.save(&path) {
// TODO(emilk): show a popup instead of logging result
Ok(()) => {
re_log::info!("Image saved to {path:?}");
}
Err(err) => {
re_log::error!(
"Failed saving image to {path:?}: {}",
re_error::format(&err)
);
re_log::error!("Failed saving image to {path:?}: {err}");
}
}
}
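
The save-path rework above boils down to: JPEG-encoded tensors now write their original bytes straight to disk, while every other tensor is re-encoded as PNG through image::DynamicImage::save. A rough sketch of just that branch decision; the helper is illustrative and not part of the diff:

use re_log_types::component_types::{Tensor, TensorData};

// Illustrative helper mirroring the match above: pick a default file name per variant.
fn default_file_name(tensor: &Tensor) -> &'static str {
    match &tensor.data {
        TensorData::JPEG(_) => "image.jpg", // raw JPEG bytes are saved as-is via write_binary
        _ => "image.png",                   // everything else goes through DynamicImage::save
    }
}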
8 changes: 2 additions & 6 deletions crates/re_viewer/src/ui/view_tensor/scene.rs
@@ -1,9 +1,6 @@
use re_arrow_store::LatestAtQuery;
use re_data_store::{EntityPath, EntityProperties, InstancePath};
use re_log_types::{
component_types::{InstanceKey, Tensor},
ClassicTensor,
};
use re_log_types::component_types::{InstanceKey, Tensor, TensorTrait};
use re_query::{query_entity_with_primary, EntityView, QueryError};

use crate::{misc::ViewerContext, ui::SceneQuery};
@@ -13,7 +10,7 @@ use crate::{misc::ViewerContext, ui::SceneQuery};
/// A tensor scene, with everything needed to render it.
#[derive(Default)]
pub struct SceneTensor {
pub tensors: std::collections::BTreeMap<InstancePath, ClassicTensor>,
pub tensors: std::collections::BTreeMap<InstancePath, Tensor>,
}

impl SceneTensor {
@@ -47,7 +44,6 @@ impl SceneTensor {
entity_view: &EntityView<Tensor>,
) -> Result<(), QueryError> {
entity_view.visit1(|instance_key: InstanceKey, tensor: Tensor| {
let tensor = ClassicTensor::from(&tensor);
if !tensor.is_shaped_like_an_image() {
let instance_path = InstancePath::instance(ent_path.clone(), instance_key);
self.tensors.insert(instance_path, tensor);

