Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert view_tensor to use the new native Tensors #1439

Merged
merged 2 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion crates/re_log_types/src/component_types/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use arrow2_convert::deserialize::ArrowDeserialize;
use arrow2_convert::field::ArrowField;
use arrow2_convert::{serialize::ArrowSerialize, ArrowDeserialize, ArrowField, ArrowSerialize};

use crate::TensorElement;
use crate::{msg_bundle::Component, ClassicTensor, TensorDataStore};
use crate::{TensorDataType, TensorElement};

pub trait TensorTrait {
fn id(&self) -> TensorId;
Expand All @@ -17,6 +17,7 @@ pub trait TensorTrait {
fn is_vector(&self) -> bool;
fn meaning(&self) -> TensorDataMeaning;
fn get(&self, index: &[u64]) -> Option<TensorElement>;
fn dtype(&self) -> TensorDataType;
}

// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -390,6 +391,21 @@ impl TensorTrait for Tensor {
TensorData::JPEG(_) => None, // Too expensive to unpack here.
}
}

fn dtype(&self) -> TensorDataType {
match &self.data {
TensorData::U8(_) | TensorData::JPEG(_) => TensorDataType::U8,
TensorData::U16(_) => TensorDataType::U16,
TensorData::U32(_) => TensorDataType::U32,
TensorData::U64(_) => TensorDataType::U64,
TensorData::I8(_) => TensorDataType::I8,
TensorData::I16(_) => TensorDataType::I16,
TensorData::I32(_) => TensorDataType::I32,
TensorData::I64(_) => TensorDataType::I64,
TensorData::F32(_) => TensorDataType::F32,
TensorData::F64(_) => TensorDataType::F64,
}
}
}

impl Component for Tensor {
Expand All @@ -412,6 +428,11 @@ pub enum TensorCastError {

#[error("ndarray Array is not contiguous and in standard order")]
NotContiguousStdOrder,

#[error(
"tensors do not currently support f16 data (https://github.com/rerun-io/rerun/issues/854)"
)]
F16NotSupported,
}

impl From<&Tensor> for ClassicTensor {
Expand Down Expand Up @@ -554,6 +575,15 @@ tensor_type!(i64, I64);
tensor_type!(f32, F32);
tensor_type!(f64, F64);

// TODO(#854) Switch back to `tensor_type!` once we have F16 tensors
impl<'a> TryFrom<&'a Tensor> for ::ndarray::ArrayViewD<'a, half::f16> {
type Error = TensorCastError;

fn try_from(_: &'a Tensor) -> Result<Self, Self::Error> {
Err(TensorCastError::F16NotSupported)
}
}

// ----------------------------------------------------------------------------

#[cfg(feature = "image")]
Expand Down
10 changes: 5 additions & 5 deletions crates/re_log_types/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ impl component_types::TensorTrait for ClassicTensor {
fn get(&self, index: &[u64]) -> Option<TensorElement> {
self.get(index)
}

#[inline]
fn dtype(&self) -> TensorDataType {
self.dtype
}
}

impl ClassicTensor {
Expand Down Expand Up @@ -394,11 +399,6 @@ impl ClassicTensor {
self.shape.as_slice()
}

#[inline]
pub fn dtype(&self) -> TensorDataType {
self.dtype
}

#[inline]
pub fn meaning(&self) -> component_types::TensorDataMeaning {
self.meaning
Expand Down
5 changes: 4 additions & 1 deletion crates/re_tensor_ops/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
//! This is particularly helpful for performing slice-operations for
//! dimensionality reduction.

use re_log_types::{component_types, ClassicTensor, TensorDataStore, TensorDataTypeTrait};
use re_log_types::{
component_types::{self, TensorTrait},
ClassicTensor, TensorDataStore, TensorDataTypeTrait,
};

pub mod dimension_mapping;

Expand Down
40 changes: 19 additions & 21 deletions crates/re_viewer/src/misc/caches/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod mesh_cache;
mod tensor_image_cache;

use re_log_types::component_types;
use re_log_types::component_types::{self, TensorTrait};
pub use tensor_image_cache::{AsDynamicImage, TensorImageView};

/// Does memoization of different things for the immediate mode UI.
Expand Down Expand Up @@ -33,9 +33,9 @@ impl Caches {
tensor_stats.clear();
}

pub fn tensor_stats(&mut self, tensor: &re_log_types::ClassicTensor) -> &TensorStats {
pub fn tensor_stats(&mut self, tensor: &re_log_types::component_types::Tensor) -> &TensorStats {
self.tensor_stats
.entry(tensor.id())
.entry(tensor.tensor_id)
.or_insert_with(|| TensorStats::new(tensor))
}
}
Expand All @@ -45,11 +45,10 @@ pub struct TensorStats {
}

impl TensorStats {
fn new(tensor: &re_log_types::ClassicTensor) -> Self {
use re_log_types::TensorDataType;
use re_tensor_ops::as_ndarray;

fn new(tensor: &re_log_types::component_types::Tensor) -> Self {
use half::f16;
use ndarray::ArrayViewD;
use re_log_types::TensorDataType;

macro_rules! declare_tensor_range_int {
($name: ident, $typ: ty) => {
Expand Down Expand Up @@ -103,21 +102,20 @@ impl TensorStats {
}

let range = match tensor.dtype() {
TensorDataType::U8 => as_ndarray::<u8>(tensor).ok().map(tensor_range_u8),
TensorDataType::U16 => as_ndarray::<u16>(tensor).ok().map(tensor_range_u16),
TensorDataType::U32 => as_ndarray::<u32>(tensor).ok().map(tensor_range_u32),
TensorDataType::U64 => as_ndarray::<u64>(tensor).ok().map(tensor_range_u64),

TensorDataType::I8 => as_ndarray::<i8>(tensor).ok().map(tensor_range_i8),
TensorDataType::I16 => as_ndarray::<i16>(tensor).ok().map(tensor_range_i16),
TensorDataType::I32 => as_ndarray::<i32>(tensor).ok().map(tensor_range_i32),
TensorDataType::I64 => as_ndarray::<i64>(tensor).ok().map(tensor_range_i64),

TensorDataType::F16 => as_ndarray::<f16>(tensor).ok().map(tensor_range_f16),
TensorDataType::F32 => as_ndarray::<f32>(tensor).ok().map(tensor_range_f32),
TensorDataType::F64 => as_ndarray::<f64>(tensor).ok().map(tensor_range_f64),
TensorDataType::U8 => ArrayViewD::<u8>::try_from(tensor).map(tensor_range_u8),
TensorDataType::U16 => ArrayViewD::<u16>::try_from(tensor).map(tensor_range_u16),
TensorDataType::U32 => ArrayViewD::<u32>::try_from(tensor).map(tensor_range_u32),
TensorDataType::U64 => ArrayViewD::<u64>::try_from(tensor).map(tensor_range_u64),

TensorDataType::I8 => ArrayViewD::<i8>::try_from(tensor).map(tensor_range_i8),
TensorDataType::I16 => ArrayViewD::<i16>::try_from(tensor).map(tensor_range_i16),
TensorDataType::I32 => ArrayViewD::<i32>::try_from(tensor).map(tensor_range_i32),
TensorDataType::I64 => ArrayViewD::<i64>::try_from(tensor).map(tensor_range_i64),
TensorDataType::F16 => ArrayViewD::<f16>::try_from(tensor).map(tensor_range_f16),
TensorDataType::F32 => ArrayViewD::<f32>::try_from(tensor).map(tensor_range_f32),
TensorDataType::F64 => ArrayViewD::<f64>::try_from(tensor).map(tensor_range_f64),
};

Self { range }
Self { range: range.ok() }
}
}
2 changes: 1 addition & 1 deletion crates/re_viewer/src/misc/caches/tensor_image_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use egui::{Color32, ColorImage};
use egui_extras::RetainedImage;
use image::DynamicImage;
use re_log_types::{
component_types::{self, ClassId, TensorDataMeaning},
component_types::{self, ClassId, TensorDataMeaning, TensorTrait},
ClassicTensor, MsgId, TensorDataStore, TensorDataType,
};
use re_renderer::{
Expand Down
53 changes: 19 additions & 34 deletions crates/re_viewer/src/ui/data_ui/image.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use egui::Vec2;
use itertools::Itertools as _;

use re_log_types::{
component_types::{ClassId, TensorDataMeaning},
ClassicTensor,
};
use re_log_types::component_types::{ClassId, Tensor, TensorDataMeaning, TensorTrait};

use crate::misc::{
caches::{TensorImageView, TensorStats},
Expand All @@ -19,19 +16,7 @@ pub fn format_tensor_shape_single_line(
format!("[{}]", shape.iter().join(", "))
}

impl DataUi for re_log_types::component_types::Tensor {
fn data_ui(
&self,
ctx: &mut ViewerContext<'_>,
ui: &mut egui::Ui,
verbosity: UiVerbosity,
query: &re_arrow_store::LatestAtQuery,
) {
ClassicTensor::from(self).data_ui(ctx, ui, verbosity, query);
}
}

impl DataUi for ClassicTensor {
impl DataUi for Tensor {
fn data_ui(
&self,
ctx: &mut ViewerContext<'_>,
Expand Down Expand Up @@ -114,7 +99,7 @@ impl DataUi for ClassicTensor {
pub fn tensor_dtype_and_shape_ui_grid_contents(
re_ui: &re_ui::ReUi,
ui: &mut egui::Ui,
tensor: &ClassicTensor,
tensor: &Tensor,
tensor_stats: Option<&TensorStats>,
) {
re_ui
Expand Down Expand Up @@ -158,7 +143,7 @@ pub fn tensor_dtype_and_shape_ui_grid_contents(
pub fn tensor_dtype_and_shape_ui(
re_ui: &re_ui::ReUi,
ui: &mut egui::Ui,
tensor: &ClassicTensor,
tensor: &Tensor,
tensor_stats: Option<&TensorStats>,
) {
egui::Grid::new("tensor_dtype_and_shape_ui")
Expand Down Expand Up @@ -452,11 +437,13 @@ fn histogram_ui(ui: &mut egui::Ui, rgb_image: &image::RgbImage) -> egui::Respons
#[cfg(not(target_arch = "wasm32"))]
fn image_options(
ui: &mut egui::Ui,
tensor: &re_log_types::ClassicTensor,
tensor: &re_log_types::component_types::Tensor,
dynamic_image: &image::DynamicImage,
) {
// TODO(emilk): support copying images on web

use re_log_types::component_types::TensorData;

#[cfg(not(target_arch = "wasm32"))]
if ui.button("Click to copy image").clicked() {
let rgba = dynamic_image.to_rgba8();
Expand All @@ -471,39 +458,37 @@ fn image_options(
// TODO(emilk): support saving images on web
#[cfg(not(target_arch = "wasm32"))]
if ui.button("Save image…").clicked() {
use re_log_types::TensorDataStore;

match &tensor.data {
TensorDataStore::Dense(_) => {
TensorData::JPEG(bytes) => {
if let Some(path) = rfd::FileDialog::new()
.set_file_name("image.png")
.set_file_name("image.jpg")
.save_file()
{
match dynamic_image.save(&path) {
// TODO(emilk): show a popup instead of logging result
match write_binary(&path, bytes.as_slice()) {
Ok(()) => {
re_log::info!("Image saved to {path:?}");
}
Err(err) => {
re_log::error!("Failed saving image to {path:?}: {err}");
re_log::error!(
"Failed saving image to {path:?}: {}",
re_error::format(&err)
);
}
}
}
}
TensorDataStore::Jpeg(bytes) => {
_ => {
if let Some(path) = rfd::FileDialog::new()
.set_file_name("image.jpg")
.set_file_name("image.png")
.save_file()
{
match write_binary(&path, bytes) {
match dynamic_image.save(&path) {
// TODO(emilk): show a popup instead of logging result
Ok(()) => {
re_log::info!("Image saved to {path:?}");
}
Err(err) => {
re_log::error!(
"Failed saving image to {path:?}: {}",
re_error::format(&err)
);
re_log::error!("Failed saving image to {path:?}: {err}");
}
}
}
Expand Down
8 changes: 2 additions & 6 deletions crates/re_viewer/src/ui/view_tensor/scene.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use re_arrow_store::LatestAtQuery;
use re_data_store::{EntityPath, EntityProperties, InstancePath};
use re_log_types::{
component_types::{InstanceKey, Tensor},
ClassicTensor,
};
use re_log_types::component_types::{InstanceKey, Tensor, TensorTrait};
use re_query::{query_entity_with_primary, EntityView, QueryError};

use crate::{misc::ViewerContext, ui::SceneQuery};
Expand All @@ -13,7 +10,7 @@ use crate::{misc::ViewerContext, ui::SceneQuery};
/// A tensor scene, with everything needed to render it.
#[derive(Default)]
pub struct SceneTensor {
pub tensors: std::collections::BTreeMap<InstancePath, ClassicTensor>,
pub tensors: std::collections::BTreeMap<InstancePath, Tensor>,
}

impl SceneTensor {
Expand Down Expand Up @@ -47,7 +44,6 @@ impl SceneTensor {
entity_view: &EntityView<Tensor>,
) -> Result<(), QueryError> {
entity_view.visit1(|instance_key: InstanceKey, tensor: Tensor| {
let tensor = ClassicTensor::from(&tensor);
if !tensor.is_shaped_like_an_image() {
let instance_path = InstancePath::instance(ent_path.clone(), instance_key);
self.tensors.insert(instance_path, tensor);
Expand Down
Loading