Skip to content

Commit

Permalink
Show tensors shaped [H, W, 1, 1] as images (and more!) (#2075)
Browse files Browse the repository at this point in the history
* Ignore trailing tensor dimension when determining whether a tensor is an image.
Fixes #1871
Quite a bit of nuance to support single channel 1x1 images & line-like images.

* fix image preview for images other than M x N x C and M x N

* comment fix

* better shape_short comment

* handle empty tensors

* is_shaped_like_an_image is now defined via image_height_width_channels, improve comment on both

* any Nx1x... tensor is now treated as an image

* rename to get_with_image_coords, make it more restrictive

* tensor_to_gpu height_width_depth utility now uses tensor.image_height_width_channels

* change behavior of is_vector
  • Loading branch information
Wumpf authored and jprochazk committed May 11, 2023
1 parent a557354 commit 45ebf13
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 62 deletions.
34 changes: 19 additions & 15 deletions crates/re_data_ui/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,17 +541,17 @@ fn tensor_pixel_value_ui(
}
});

let text = match tensor.num_dim() {
2 => tensor.get(&[y, x]).map(|v| format!("Val: {v}")),
3 => match tensor.shape()[2].size {
0 => Some("Cannot preview 0-size channel".to_owned()),
1 => tensor.get(&[y, x, 0]).map(|v| format!("Val: {v}")),
let text = if let Some([_, _, channel]) = tensor.image_height_width_channels() {
match channel {
1 => tensor
.get_with_image_coords(x, y, 0)
.map(|v| format!("Val: {v}")),
3 => {
// TODO(jleibs): Track RGB ordering somehow -- don't just assume it
if let (Some(r), Some(g), Some(b)) = (
tensor.get(&[y, x, 0]),
tensor.get(&[y, x, 1]),
tensor.get(&[y, x, 2]),
tensor.get_with_image_coords(x, y, 0),
tensor.get_with_image_coords(x, y, 1),
tensor.get_with_image_coords(x, y, 2),
) {
match (r, g, b) {
(TensorElement::U8(r), TensorElement::U8(g), TensorElement::U8(b)) => {
Expand All @@ -566,10 +566,10 @@ fn tensor_pixel_value_ui(
4 => {
// TODO(jleibs): Track RGB ordering somehow -- don't just assume it
if let (Some(r), Some(g), Some(b), Some(a)) = (
tensor.get(&[y, x, 0]),
tensor.get(&[y, x, 1]),
tensor.get(&[y, x, 2]),
tensor.get(&[y, x, 3]),
tensor.get_with_image_coords(x, y, 0),
tensor.get_with_image_coords(x, y, 1),
tensor.get_with_image_coords(x, y, 2),
tensor.get_with_image_coords(x, y, 3),
) {
match (r, g, b, a) {
(
Expand All @@ -586,9 +586,13 @@ fn tensor_pixel_value_ui(
None
}
}
channels => Some(format!("Cannot preview {channels}-channel image")),
},
dims => Some(format!("Cannot preview {dims}-dimensional image")),
channel => Some(format!("Cannot preview {channel}-size channel image")),
}
} else {
Some(format!(
"Cannot preview tensors with a shape of {:?}",
tensor.shape()
))
};

if let Some(text) = text {
Expand Down
228 changes: 206 additions & 22 deletions crates/re_log_types/src/component_types/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,50 +407,118 @@ impl Tensor {
self.shape.as_slice()
}

/// Returns the shape of the tensor with every trailing dimension of size 1 stripped.
///
/// If all dimensions have size 1, only the first dimension is kept.
/// A tensor without any dimensions yields an empty slice.
#[inline]
pub fn shape_short(&self) -> &[TensorDimension] {
    // Search from the back for the last dimension whose size is not 1.
    let Some(last_significant) = self.shape.iter().rposition(|dim| dim.size != 1) else {
        // Either there are no dimensions at all (keep nothing) or every
        // dimension has size 1 (keep just the first one).
        let kept = self.shape.len().min(1);
        return &self.shape[..kept];
    };

    &self.shape[..=last_significant]
}

/// Returns the number of dimensions of the tensor, i.e. the length of its shape.
///
/// Trailing dimensions of size 1 are counted as well.
#[inline]
pub fn num_dim(&self) -> usize {
self.shape.len()
}

/// If this tensor is shaped as an image, return the height, width, and channels/depth of it.
/// If the tensor can be interpreted as an image, return the height, width, and channels/depth of it.
pub fn image_height_width_channels(&self) -> Option<[u64; 3]> {
if self.shape.len() == 2 {
Some([self.shape[0].size, self.shape[1].size, 1])
} else if self.shape.len() == 3 {
let channels = self.shape[2].size;
// gray, rgb, rgba
if matches!(channels, 1 | 3 | 4) {
Some([self.shape[0].size, self.shape[1].size, channels])
} else {
None
let shape_short = self.shape_short();

match shape_short.len() {
1 => {
// Special case: Nx1(x1x1x...) tensors are treated as Nx1 grey images.
if self.shape.len() >= 2 {
Some([shape_short[0].size, 1, 1])
} else {
None
}
}
} else {
None
2 => Some([shape_short[0].size, shape_short[1].size, 1]),
3 => {
let channels = shape_short[2].size;
if matches!(channels, 3 | 4) {
// rgb, rgba
Some([shape_short[0].size, shape_short[1].size, channels])
} else {
None
}
}
_ => None,
}
}

/// Returns true if the tensor can be interpreted as an image.
pub fn is_shaped_like_an_image(&self) -> bool {
self.num_dim() == 2
|| self.num_dim() == 3 && {
matches!(
self.shape.last().unwrap().size,
// gray, rgb, rgba
1 | 3 | 4
)
}
self.image_height_width_channels().is_some()
}

/// Returns true if either all dimensions have size 1 or only a single dimension has a size larger than 1.
///
/// Empty tensors return false.
#[inline]
pub fn is_vector(&self) -> bool {
let shape = &self.shape;
shape.len() == 1 || { shape.len() == 2 && (shape[0].size == 1 || shape[1].size == 1) }
if self.shape.is_empty() {
false
} else {
self.shape.iter().filter(|dim| dim.size > 1).count() <= 1
}
}

/// Returns the [`TensorDataMeaning`] stored on this tensor.
#[inline]
pub fn meaning(&self) -> TensorDataMeaning {
self.meaning
}

/// Query a single element using image coordinates: `x` (column), `y` (row) and `channel`.
///
/// Allows querying values of any image-like tensor, even if it has more or fewer than
/// 3 dimensions (useful for sampling e.g. an `H x W x C x 1` tensor, which is a valid image).
///
/// Returns `None` if the coordinates cannot be mapped onto this tensor's shape,
/// or if the underlying [`Self::get`] lookup returns `None`.
#[inline]
pub fn get_with_image_coords(&self, x: u64, y: u64, channel: u64) -> Option<TensorElement> {
    match self.shape.len() {
        1 => {
            // A one-dimensional tensor only has a single row and a single channel.
            if y == 0 && channel == 0 {
                self.get(&[x])
            } else {
                None
            }
        }
        2 => {
            // `height x width` tensors only have a single channel.
            if channel == 0 {
                self.get(&[y, x])
            } else {
                None
            }
        }
        3 => self.get(&[y, x, channel]),
        4 => {
            // Fast path for the common `H x W x C x 1` case that avoids allocating an index.
            if self.shape[3].size == 1 {
                self.get(&[y, x, channel, 0])
            } else {
                None
            }
        }
        dim => {
            // Number of dimensions beyond height/width/channel.
            // `checked_sub` bails out for rank 0 instead of underflowing.
            let num_trailing = dim.checked_sub(3)?;

            // Higher-rank tensors are only addressable if they can be interpreted as an image.
            self.image_height_width_channels()?;

            // Dimensions are ordered `[height, width, channel, ...]`, so the row (`y`) comes
            // first, exactly like in all the other arms. All remaining dimensions are indexed at 0.
            let index: Vec<u64> = [y, x, channel]
                .into_iter()
                .chain(std::iter::repeat(0).take(num_trailing))
                .collect();
            self.get(&index)
        }
    }
}

pub fn get(&self, index: &[u64]) -> Option<TensorElement> {
let mut stride: usize = 1;
let mut offset: usize = 0;
Expand Down Expand Up @@ -1164,3 +1232,119 @@ fn test_arrow() {
let tensors_out: Vec<Tensor> = TryIntoCollection::try_into_collection(array).unwrap();
assert_eq!(tensors_in, tensors_out);
}

#[test]
fn test_tensor_shape_utilities() {
    /// Builds a `u32` tensor with the given dimension sizes, filled with `0, 1, 2, ...`.
    fn generate_tensor_from_shape(sizes: &[u64]) -> Tensor {
        let shape = sizes
            .iter()
            .map(|&size| TensorDimension { size, name: None })
            .collect();
        // The element count is the product of all dimension sizes.
        // (Folding a multiplication from 0 would always produce 0 and leave the buffer empty.)
        let num_elements: u64 = sizes.iter().product();
        let data = (0..num_elements).map(|i| i as u32).collect::<Vec<_>>();

        Tensor {
            tensor_id: TensorId(std::default::Default::default()),
            shape,
            data: TensorData::U32(data.into()),
            meaning: TensorDataMeaning::Unknown,
            meter: None,
        }
    }

    // Tensor without any dimensions.
    {
        let tensor = generate_tensor_from_shape(&[]);

        assert_eq!(tensor.image_height_width_channels(), None);
        assert_eq!(tensor.shape_short(), tensor.shape());
        assert!(!tensor.is_vector());
        assert!(!tensor.is_shaped_like_an_image());
    }

    // Single dimension tensors.
    for shape in [vec![4], vec![1]] {
        let tensor = generate_tensor_from_shape(&shape);

        assert_eq!(tensor.image_height_width_channels(), None);
        assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
        assert!(tensor.is_vector());
        assert!(!tensor.is_shaped_like_an_image());
    }

    // Single element, but it might be interpreted as a 1x1 grey image!
    for shape in [
        vec![1, 1],
        vec![1, 1, 1],
        vec![1, 1, 1, 1],
        vec![1, 1, 1, 1, 1],
    ] {
        let tensor = generate_tensor_from_shape(&shape);

        assert_eq!(tensor.image_height_width_channels(), Some([1, 1, 1]));
        assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
        assert!(tensor.is_vector());
        assert!(tensor.is_shaped_like_an_image());
    }

    // Color/Grey images with height 4 and width 2.
    for shape in [
        vec![4, 2],
        vec![4, 2, 1],
        vec![4, 2, 1, 1],
        vec![4, 2, 3],
        vec![4, 2, 3, 1, 1],
        vec![4, 2, 4],
        vec![4, 2, 4, 1, 1, 1, 1],
    ] {
        let tensor = generate_tensor_from_shape(&shape);
        // A missing third dimension means a single (grey) channel.
        let channels = shape.get(2).copied().unwrap_or(1);

        assert_eq!(tensor.image_height_width_channels(), Some([4, 2, channels]));
        // The channel dimension is only part of the short shape if it is not of size 1.
        assert_eq!(
            tensor.shape_short(),
            &tensor.shape()[0..(2 + (channels != 1) as usize)]
        );
        assert!(!tensor.is_vector());
        assert!(tensor.is_shaped_like_an_image());
    }

    // Grey images with height 4 and width 1.
    for shape in [
        vec![4, 1],
        vec![4, 1, 1],
        vec![4, 1, 1, 1],
        vec![4, 1, 1, 1, 1],
    ] {
        let tensor = generate_tensor_from_shape(&shape);

        assert_eq!(tensor.image_height_width_channels(), Some([4, 1, 1]));
        assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
        assert!(tensor.is_vector());
        assert!(tensor.is_shaped_like_an_image());
    }

    // Grey images with height 1 and width 4.
    for shape in [
        vec![1, 4],
        vec![1, 4, 1],
        vec![1, 4, 1, 1],
        vec![1, 4, 1, 1, 1],
    ] {
        let tensor = generate_tensor_from_shape(&shape);

        assert_eq!(tensor.image_height_width_channels(), Some([1, 4, 1]));
        assert_eq!(tensor.shape_short(), &tensor.shape()[0..2]);
        assert!(tensor.is_vector());
        assert!(tensor.is_shaped_like_an_image());
    }

    // Non images & non vectors without trailing dimensions
    for shape in [vec![4, 2, 5], vec![1, 1, 1, 2, 4]] {
        let tensor = generate_tensor_from_shape(&shape);

        assert_eq!(tensor.image_height_width_channels(), None);
        assert_eq!(tensor.shape_short(), tensor.shape());
        assert!(!tensor.is_vector());
        assert!(!tensor.is_shaped_like_an_image());
    }
}
7 changes: 2 additions & 5 deletions crates/re_viewer/src/ui/space_view_heuristics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,7 @@ fn default_created_space_views_from_candidates(
&[],
) {
for tensor in entity_view.iter_primary_flattened() {
if tensor.is_shaped_like_an_image() {
debug_assert!(matches!(tensor.shape.len(), 2 | 3));

if let Some([height, width, _]) = tensor.image_height_width_channels() {
if query_latest_single::<re_log_types::DrawOrder>(
entity_db,
entity_path,
Expand All @@ -205,9 +203,8 @@ fn default_created_space_views_from_candidates(
.push(entity_path.clone());
} else {
// Otherwise, distinguish buckets by image size.
let dim = (tensor.shape[0].size, tensor.shape[1].size);
images_by_bucket
.entry(ImageBucketing::BySize(dim))
.entry(ImageBucketing::BySize((height, width)))
.or_default()
.push(entity_path.clone());
}
Expand Down
25 changes: 6 additions & 19 deletions crates/re_viewer_context/src/gpu_bridge/tensor_to_gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,27 +471,14 @@ fn pad_and_narrow_and_cast<T: Copy + Pod>(
fn height_width_depth(tensor: &Tensor) -> anyhow::Result<[u32; 3]> {
use anyhow::Context as _;

let shape = &tensor.shape();

anyhow::ensure!(
shape.len() == 2 || shape.len() == 3,
"Expected a 2D or 3D tensor, got {shape:?}",
);
let Some([height, width, channel]) = tensor.image_height_width_channels() else {
anyhow::bail!("Tensor is not an image");
};

let [height, width] = [
u32::try_from(shape[0].size).context("tensor too large")?,
u32::try_from(shape[1].size).context("tensor too large")?,
u32::try_from(height).context("Image height is too large")?,
u32::try_from(width).context("Image width is too large")?,
];
let depth = if shape.len() == 2 { 1 } else { shape[2].size };

anyhow::ensure!(
depth == 1 || depth == 3 || depth == 4,
"Expected depth of 1,3,4 (gray, RGB, RGBA), found {depth:?}. Tensor shape: {shape:?}"
);
debug_assert!(
tensor.is_shaped_like_an_image(),
"We should make the same checks above, but with actual error messages"
);

Ok([height, width, depth as u32])
Ok([height, width, channel as u32])
}
2 changes: 1 addition & 1 deletion examples/rust/api_demo/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ fn demo_rects(rec_stream: &RecordingStream) -> anyhow::Result<()> {
use ndarray_rand::{rand_distr::Uniform, RandomExt as _};

// Add an image
let img = Array::<u8, _>::from_elem((1024, 1024, 3).f(), 128);
let img = Array::<u8, _>::from_elem((1024, 1024, 3, 1).f(), 128);
MsgSender::new("rects_demo/img")
.with_timepoint(sim_time(1 as _))
.with_component(&[Tensor::try_from(img.as_standard_layout().view())?])?
Expand Down

0 comments on commit 45ebf13

Please sign in to comment.