Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show tensors shaped [H, W, 1, 1] as images (and more!) #2075

Merged
merged 11 commits into from
May 10, 2023
34 changes: 19 additions & 15 deletions crates/re_data_ui/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,17 +540,17 @@ fn tensor_pixel_value_ui(
}
});

let text = match tensor.num_dim() {
2 => tensor.get(&[y, x]).map(|v| format!("Val: {v}")),
3 => match tensor.shape()[2].size {
0 => Some("Cannot preview 0-size channel".to_owned()),
1 => tensor.get(&[y, x, 0]).map(|v| format!("Val: {v}")),
let text = if let Some([_, _, channel]) = tensor.image_height_width_channels() {
match channel {
1 => tensor
.get_with_image_coords(x, y, 0)
.map(|v| format!("Val: {v}")),
3 => {
// TODO(jleibs): Track RGB ordering somehow -- don't just assume it
if let (Some(r), Some(g), Some(b)) = (
tensor.get(&[y, x, 0]),
tensor.get(&[y, x, 1]),
tensor.get(&[y, x, 2]),
tensor.get_with_image_coords(x, y, 0),
tensor.get_with_image_coords(x, y, 1),
tensor.get_with_image_coords(x, y, 2),
) {
match (r, g, b) {
(TensorElement::U8(r), TensorElement::U8(g), TensorElement::U8(b)) => {
Expand All @@ -565,10 +565,10 @@ fn tensor_pixel_value_ui(
4 => {
// TODO(jleibs): Track RGB ordering somehow -- don't just assume it
if let (Some(r), Some(g), Some(b), Some(a)) = (
tensor.get(&[y, x, 0]),
tensor.get(&[y, x, 1]),
tensor.get(&[y, x, 2]),
tensor.get(&[y, x, 3]),
tensor.get_with_image_coords(x, y, 0),
tensor.get_with_image_coords(x, y, 1),
tensor.get_with_image_coords(x, y, 2),
tensor.get_with_image_coords(x, y, 3),
) {
match (r, g, b, a) {
(
Expand All @@ -585,9 +585,13 @@ fn tensor_pixel_value_ui(
None
}
}
channels => Some(format!("Cannot preview {channels}-channel image")),
},
dims => Some(format!("Cannot preview {dims}-dimensional image")),
channel => Some(format!("Cannot preview {channel}-size channel image")),
}
} else {
Some(format!(
"Cannot preview tensors with a shape of {:?}",
tensor.shape()
))
};

if let Some(text) = text {
Expand Down
228 changes: 206 additions & 22 deletions crates/re_log_types/src/component_types/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,50 +399,118 @@ impl Tensor {
self.shape.as_slice()
}

/// Returns the shape of the tensor with all trailing dimensions of size 1 ignored.
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
///
/// If all dimension sizes are one, this returns only the first dimension.
#[inline]
pub fn shape_short(&self) -> &[TensorDimension] {
    // Position of the last dimension whose size is not 1, searching from the back.
    let last_significant = self.shape.iter().rposition(|dim| dim.size != 1);

    match last_significant {
        // Keep everything up to and including that dimension.
        Some(index) => &self.shape[..=index],
        // All dimensions have size 1: keep only the first one.
        // For an empty shape this yields an empty slice.
        None => &self.shape[..self.shape.len().min(1)],
    }
}

/// Returns the number of dimensions of the tensor, i.e. the length of its shape.
///
/// Every dimension is counted, including trailing dimensions of size 1.
#[inline]
pub fn num_dim(&self) -> usize {
self.shape.len()
}

/// If this tensor is shaped as an image, return the height, width, and channels/depth of it.
/// If the tensor can be interpreted as an image, return the height, width, and channels/depth of it.
pub fn image_height_width_channels(&self) -> Option<[u64; 3]> {
if self.shape.len() == 2 {
Some([self.shape[0].size, self.shape[1].size, 1])
} else if self.shape.len() == 3 {
let channels = self.shape[2].size;
// gray, rgb, rgba
if matches!(channels, 1 | 3 | 4) {
Some([self.shape[0].size, self.shape[1].size, channels])
} else {
None
let shape_short = self.shape_short();

match shape_short.len() {
1 => {
Comment on lines +428 to +429
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If `shape_short` returned just the integer sizes (not the dimension names), we could make this match a lot nicer:

Suggested change
match shape_short.len() {
1 => {
match shape_short {
[h] => {},
[h, w] => {},
[h, w, c] => {},

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

…but that would require an allocation, or changing how the shape is stored (see #1992 (comment)), so let's not do that right now.

// Special case: Nx1(x1x1x...) tensors are treated as Nx1 grey images.
if self.shape.len() >= 2 {
Some([shape_short[0].size, 1, 1])
} else {
None
}
}
} else {
None
2 => Some([shape_short[0].size, shape_short[1].size, 1]),
3 => {
let channels = shape_short[2].size;
if matches!(channels, 3 | 4) {
// rgb, rgba
Some([shape_short[0].size, shape_short[1].size, channels])
} else {
None
}
}
_ => None,
}
}

/// Returns true if the tensor can be interpreted as an image.
pub fn is_shaped_like_an_image(&self) -> bool {
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
self.num_dim() == 2
|| self.num_dim() == 3 && {
matches!(
self.shape.last().unwrap().size,
// gray, rgb, rgba
1 | 3 | 4
)
}
self.image_height_width_channels().is_some()
}

/// Returns true if either all dimensions have size 1 or only a single dimension has a size larger than 1.
///
/// Empty tensors return false.
#[inline]
pub fn is_vector(&self) -> bool {
let shape = &self.shape;
shape.len() == 1 || { shape.len() == 2 && (shape[0].size == 1 || shape[1].size == 1) }
if self.shape.is_empty() {
false
} else {
self.shape.iter().filter(|dim| dim.size > 1).count() <= 1
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Returns the `TensorDataMeaning` stored on this tensor.
#[inline]
pub fn meaning(&self) -> TensorDataMeaning {
self.meaning
}

/// Query a single element using image coordinates: `x` (column), `y` (row) and `channel`.
///
/// This allows querying values of any image-like tensor, even if it has more or fewer
/// than three dimensions (useful e.g. for sampling an `N x M x C x 1` tensor, which is a
/// valid image).
///
/// Returns `None` if the coordinates cannot be mapped onto this tensor's shape, or if the
/// underlying [`Self::get`] lookup returns `None`.
#[inline]
pub fn get_with_image_coords(&self, x: u64, y: u64, channel: u64) -> Option<TensorElement> {
    match self.shape.len() {
        1 => {
            // A 1D tensor only has a single row and a single channel.
            if y == 0 && channel == 0 {
                self.get(&[x])
            } else {
                None
            }
        }
        2 => {
            // `H x W`: a single-channel image.
            if channel == 0 {
                self.get(&[y, x])
            } else {
                None
            }
        }
        3 => self.get(&[y, x, channel]),
        4 => {
            // Fast path for the common `H x W x C x 1` case: it avoids the allocation
            // done by the general fallback arm below.
            if self.shape[3].size == 1 {
                self.get(&[y, x, channel, 0])
            } else {
                None
            }
        }
        // General fallback: only image-like tensors are sampled, and every dimension after
        // the channel is indexed at 0. For an empty shape `image_height_width_channels`
        // returns `None`, so the closure (and `dim - 3`) is never evaluated.
        dim => self.image_height_width_channels().and_then(|_| {
            self.get(
                // Height comes first, exactly as in the arms above: `[y, x, channel, 0, ...]`.
                &[y, x, channel]
                    .into_iter()
                    .chain(std::iter::repeat(0).take(dim - 3))
                    .collect::<Vec<u64>>(),
            )
        }),
Comment on lines +503 to +510
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should return `None` when `get_with_image_coords` is used on a tensor that is not image-like, e.g. 8x7x6x5x4x3.

So maybe just check that `self.shape_short().len() <= 3` in this branch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check here, implied by `self.image_height_width_channels().and_then`, is actually stronger and would also reject e.g. 3x3x5.
So I think this is fine as it is (albeit more expensive).

}
}

pub fn get(&self, index: &[u64]) -> Option<TensorElement> {
let mut stride: usize = 1;
let mut offset: usize = 0;
Expand Down Expand Up @@ -1080,3 +1148,119 @@ fn test_arrow() {
let tensors_out: Vec<Tensor> = TryIntoCollection::try_into_collection(array).unwrap();
assert_eq!(tensors_in, tensors_out);
}

#[test]
fn test_tensor_shape_utilities() {
/// Builds a `U32` tensor with the given dimension sizes (all dimensions unnamed),
/// filled with the values `0..num_elements`.
fn generate_tensor_from_shape(sizes: &[u64]) -> Tensor {
    let shape = sizes
        .iter()
        .map(|&size| TensorDimension { size, name: None })
        .collect();
    // The element count is the product of all dimension sizes. Folding from an initial
    // value of 0 would always yield 0 and leave every tensor without any data.
    let num_elements: u64 = sizes.iter().product();
    let data = (0..num_elements).map(|i| i as u32).collect::<Vec<_>>();

    Tensor {
        tensor_id: TensorId(std::default::Default::default()),
        shape,
        data: TensorData::U32(data.into()),
        meaning: TensorDataMeaning::Unknown,
        meter: None,
    }
}

Wumpf marked this conversation as resolved.
Show resolved Hide resolved
// Empty tensor.
{
let tensor = generate_tensor_from_shape(&[]);

assert_eq!(tensor.image_height_width_channels(), None);
assert_eq!(tensor.shape_short(), tensor.shape());
assert!(!tensor.is_vector());
assert!(!tensor.is_shaped_like_an_image());
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
}

// Single dimension tensors.
for shape in [vec![4], vec![1]] {
let tensor = generate_tensor_from_shape(&shape);

assert_eq!(tensor.image_height_width_channels(), None);
assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
assert!(tensor.is_vector());
assert!(!tensor.is_shaped_like_an_image());
}

// Single element, but it might be interpreted as a 1x1 grey image!
for shape in [
vec![1, 1],
vec![1, 1, 1],
vec![1, 1, 1, 1],
vec![1, 1, 1, 1, 1],
] {
let tensor = generate_tensor_from_shape(&shape);

assert_eq!(tensor.image_height_width_channels(), Some([1, 1, 1]));
assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
assert!(tensor.is_vector());
assert!(tensor.is_shaped_like_an_image());
}
// Color/Grey 2x4 images
for shape in [
vec![4, 2],
vec![4, 2, 1],
vec![4, 2, 1, 1],
vec![4, 2, 3],
vec![4, 2, 3, 1, 1],
vec![4, 2, 4],
vec![4, 2, 4, 1, 1, 1, 1],
] {
let tensor = generate_tensor_from_shape(&shape);
let channels = shape.get(2).cloned().unwrap_or(1);

assert_eq!(tensor.image_height_width_channels(), Some([4, 2, channels]));
assert_eq!(
tensor.shape_short(),
&tensor.shape()[0..(2 + (channels != 1) as usize)]
);
assert!(!tensor.is_vector());
assert!(tensor.is_shaped_like_an_image());
}

// Grey 1x4 images
for shape in [
vec![4, 1],
vec![4, 1, 1],
vec![4, 1, 1, 1],
vec![4, 1, 1, 1, 1],
] {
let tensor = generate_tensor_from_shape(&shape);

assert_eq!(tensor.image_height_width_channels(), Some([4, 1, 1]));
assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
assert!(tensor.is_vector());
assert!(tensor.is_shaped_like_an_image());
}

// Grey 4x1 images
for shape in [
vec![1, 4],
vec![1, 4, 1],
vec![1, 4, 1, 1],
vec![1, 4, 1, 1, 1],
] {
let tensor = generate_tensor_from_shape(&shape);

assert_eq!(tensor.image_height_width_channels(), Some([1, 4, 1]));
assert_eq!(tensor.shape_short(), &tensor.shape()[0..2]);
assert!(tensor.is_vector());
assert!(tensor.is_shaped_like_an_image());
}

// Non images & non vectors without trailing dimensions
for shape in [vec![4, 2, 5], vec![1, 1, 1, 2, 4]] {
let tensor = generate_tensor_from_shape(&shape);

assert_eq!(tensor.image_height_width_channels(), None);
assert_eq!(tensor.shape_short(), tensor.shape());
assert!(!tensor.is_vector());
assert!(!tensor.is_shaped_like_an_image());
}
}
7 changes: 2 additions & 5 deletions crates/re_viewer/src/ui/space_view_heuristics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,7 @@ fn default_created_space_views_from_candidates(
&[],
) {
for tensor in entity_view.iter_primary_flattened() {
if tensor.is_shaped_like_an_image() {
debug_assert!(matches!(tensor.shape.len(), 2 | 3));

if let Some([height, width, _]) = tensor.image_height_width_channels() {
if query_latest_single::<re_log_types::DrawOrder>(
entity_db,
entity_path,
Expand All @@ -205,9 +203,8 @@ fn default_created_space_views_from_candidates(
.push(entity_path.clone());
} else {
// Otherwise, distinguish buckets by image size.
let dim = (tensor.shape[0].size, tensor.shape[1].size);
images_by_bucket
.entry(ImageBucketing::BySize(dim))
.entry(ImageBucketing::BySize((height, width)))
.or_default()
.push(entity_path.clone());
}
Expand Down
25 changes: 6 additions & 19 deletions crates/re_viewer_context/src/gpu_bridge/tensor_to_gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,27 +465,14 @@ fn pad_and_narrow_and_cast<T: Copy + Pod>(
fn height_width_depth(tensor: &Tensor) -> anyhow::Result<[u32; 3]> {
use anyhow::Context as _;

let shape = &tensor.shape();

anyhow::ensure!(
shape.len() == 2 || shape.len() == 3,
"Expected a 2D or 3D tensor, got {shape:?}",
);
let Some([height, width, channel]) = tensor.image_height_width_channels() else {
anyhow::bail!("Tensor is not an image");
};

let [height, width] = [
u32::try_from(shape[0].size).context("tensor too large")?,
u32::try_from(shape[1].size).context("tensor too large")?,
u32::try_from(height).context("Image height is too large")?,
u32::try_from(width).context("Image width is too large")?,
];
let depth = if shape.len() == 2 { 1 } else { shape[2].size };

anyhow::ensure!(
depth == 1 || depth == 3 || depth == 4,
"Expected depth of 1,3,4 (gray, RGB, RGBA), found {depth:?}. Tensor shape: {shape:?}"
);
debug_assert!(
tensor.is_shaped_like_an_image(),
"We should make the same checks above, but with actual error messages"
);

Ok([height, width, depth as u32])
Ok([height, width, channel as u32])
}
2 changes: 1 addition & 1 deletion examples/rust/api_demo/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ fn demo_rects(rec_stream: &RecordingStream) -> anyhow::Result<()> {
use ndarray_rand::{rand_distr::Uniform, RandomExt as _};

// Add an image
let img = Array::<u8, _>::from_elem((1024, 1024, 3).f(), 128);
let img = Array::<u8, _>::from_elem((1024, 1024, 3, 1).f(), 128);
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
MsgSender::new("rects_demo/img")
.with_timepoint(sim_time(1 as _))
.with_component(&[Tensor::try_from(img.as_standard_layout().view())?])?
Expand Down