Skip to content

Commit 45ebf13

Browse files
Wumpfjprochazk
authored andcommitted
Show tensors shaped [H, W, 1, 1] as images (and more!) (#2075)
* Ignore trailing tensor dimension when determining whether a tensor is an image. Fixes #1871 Quite a bit of nuance to support single channel 1x1 images & line-like images. * fix image preview for images other than M x N x C and M x N * comment fix * better shape_short comment * handle empty tensors * is_shaped_like_an_image is now defined via image_height_width_channels, improve comment on both * any Nx1x... image now now treated as image * rename to get_with_image_coords, make it more restrictive * tensor_to_gpu height_width_depth utility uses now tensor.image_height_width_channels * change behavior of is_vector
1 parent a557354 commit 45ebf13

File tree

5 files changed

+234
-62
lines changed

5 files changed

+234
-62
lines changed

crates/re_data_ui/src/image.rs

+19-15
Original file line numberDiff line numberDiff line change
@@ -541,17 +541,17 @@ fn tensor_pixel_value_ui(
541541
}
542542
});
543543

544-
let text = match tensor.num_dim() {
545-
2 => tensor.get(&[y, x]).map(|v| format!("Val: {v}")),
546-
3 => match tensor.shape()[2].size {
547-
0 => Some("Cannot preview 0-size channel".to_owned()),
548-
1 => tensor.get(&[y, x, 0]).map(|v| format!("Val: {v}")),
544+
let text = if let Some([_, _, channel]) = tensor.image_height_width_channels() {
545+
match channel {
546+
1 => tensor
547+
.get_with_image_coords(x, y, 0)
548+
.map(|v| format!("Val: {v}")),
549549
3 => {
550550
// TODO(jleibs): Track RGB ordering somehow -- don't just assume it
551551
if let (Some(r), Some(g), Some(b)) = (
552-
tensor.get(&[y, x, 0]),
553-
tensor.get(&[y, x, 1]),
554-
tensor.get(&[y, x, 2]),
552+
tensor.get_with_image_coords(x, y, 0),
553+
tensor.get_with_image_coords(x, y, 1),
554+
tensor.get_with_image_coords(x, y, 2),
555555
) {
556556
match (r, g, b) {
557557
(TensorElement::U8(r), TensorElement::U8(g), TensorElement::U8(b)) => {
@@ -566,10 +566,10 @@ fn tensor_pixel_value_ui(
566566
4 => {
567567
// TODO(jleibs): Track RGB ordering somehow -- don't just assume it
568568
if let (Some(r), Some(g), Some(b), Some(a)) = (
569-
tensor.get(&[y, x, 0]),
570-
tensor.get(&[y, x, 1]),
571-
tensor.get(&[y, x, 2]),
572-
tensor.get(&[y, x, 3]),
569+
tensor.get_with_image_coords(x, y, 0),
570+
tensor.get_with_image_coords(x, y, 1),
571+
tensor.get_with_image_coords(x, y, 2),
572+
tensor.get_with_image_coords(x, y, 3),
573573
) {
574574
match (r, g, b, a) {
575575
(
@@ -586,9 +586,13 @@ fn tensor_pixel_value_ui(
586586
None
587587
}
588588
}
589-
channels => Some(format!("Cannot preview {channels}-channel image")),
590-
},
591-
dims => Some(format!("Cannot preview {dims}-dimensional image")),
589+
channel => Some(format!("Cannot preview {channel}-size channel image")),
590+
}
591+
} else {
592+
Some(format!(
593+
"Cannot preview tensors with a shape of {:?}",
594+
tensor.shape()
595+
))
592596
};
593597

594598
if let Some(text) = text {

crates/re_log_types/src/component_types/tensor.rs

+206-22
Original file line numberDiff line numberDiff line change
@@ -407,50 +407,118 @@ impl Tensor {
407407
self.shape.as_slice()
408408
}
409409

410+
/// Returns the shape of the tensor with all trailing dimensions of size 1 ignored.
411+
///
412+
/// If all dimension sizes are one, this returns only the first dimension.
413+
#[inline]
414+
pub fn shape_short(&self) -> &[TensorDimension] {
415+
if self.shape.is_empty() {
416+
&self.shape
417+
} else {
418+
self.shape
419+
.iter()
420+
.enumerate()
421+
.rev()
422+
.find(|(_, dim)| dim.size != 1)
423+
.map_or(&self.shape[0..1], |(i, _)| &self.shape[..(i + 1)])
424+
}
425+
}
426+
410427
#[inline]
411428
pub fn num_dim(&self) -> usize {
412429
self.shape.len()
413430
}
414431

415-
/// If this tensor is shaped as an image, return the height, width, and channels/depth of it.
432+
/// If the tensor can be interpreted as an image, return the height, width, and channels/depth of it.
416433
pub fn image_height_width_channels(&self) -> Option<[u64; 3]> {
417-
if self.shape.len() == 2 {
418-
Some([self.shape[0].size, self.shape[1].size, 1])
419-
} else if self.shape.len() == 3 {
420-
let channels = self.shape[2].size;
421-
// gray, rgb, rgba
422-
if matches!(channels, 1 | 3 | 4) {
423-
Some([self.shape[0].size, self.shape[1].size, channels])
424-
} else {
425-
None
434+
let shape_short = self.shape_short();
435+
436+
match shape_short.len() {
437+
1 => {
438+
// Special case: Nx1(x1x1x...) tensors are treated as Nx1 grey images.
439+
if self.shape.len() >= 2 {
440+
Some([shape_short[0].size, 1, 1])
441+
} else {
442+
None
443+
}
426444
}
427-
} else {
428-
None
445+
2 => Some([shape_short[0].size, shape_short[1].size, 1]),
446+
3 => {
447+
let channels = shape_short[2].size;
448+
if matches!(channels, 3 | 4) {
449+
// rgb, rgba
450+
Some([shape_short[0].size, shape_short[1].size, channels])
451+
} else {
452+
None
453+
}
454+
}
455+
_ => None,
429456
}
430457
}
431458

459+
/// Returns true if the tensor can be interpreted as an image.
432460
pub fn is_shaped_like_an_image(&self) -> bool {
433-
self.num_dim() == 2
434-
|| self.num_dim() == 3 && {
435-
matches!(
436-
self.shape.last().unwrap().size,
437-
// gray, rgb, rgba
438-
1 | 3 | 4
439-
)
440-
}
461+
self.image_height_width_channels().is_some()
441462
}
442463

464+
/// Returns true if either all dimensions have size 1 or only a single dimension has a size larger than 1.
465+
///
466+
/// Empty tensors return false.
443467
#[inline]
444468
pub fn is_vector(&self) -> bool {
445-
let shape = &self.shape;
446-
shape.len() == 1 || { shape.len() == 2 && (shape[0].size == 1 || shape[1].size == 1) }
469+
if self.shape.is_empty() {
470+
false
471+
} else {
472+
self.shape.iter().filter(|dim| dim.size > 1).count() <= 1
473+
}
447474
}
448475

449476
#[inline]
450477
pub fn meaning(&self) -> TensorDataMeaning {
451478
self.meaning
452479
}
453480

481+
/// Query with x, y, channel indices.
482+
///
483+
/// Allows to query values for any image like tensor even if it has more or less dimensions than 3.
484+
/// (useful for sampling e.g. `N x M x C x 1` tensor which is a valid image)
485+
#[inline]
486+
pub fn get_with_image_coords(&self, x: u64, y: u64, channel: u64) -> Option<TensorElement> {
487+
match self.shape.len() {
488+
1 => {
489+
if y == 0 && channel == 0 {
490+
self.get(&[x])
491+
} else {
492+
None
493+
}
494+
}
495+
2 => {
496+
if channel == 0 {
497+
self.get(&[y, x])
498+
} else {
499+
None
500+
}
501+
}
502+
3 => self.get(&[y, x, channel]),
503+
4 => {
504+
// Optimization for common case, next case handles this too.
505+
if self.shape[3].size == 1 {
506+
self.get(&[y, x, channel, 0])
507+
} else {
508+
None
509+
}
510+
}
511+
dim => self.image_height_width_channels().and_then(|_| {
512+
self.get(
513+
&[x, y, channel]
514+
.into_iter()
515+
.chain(std::iter::repeat(0).take(dim - 3))
516+
.collect::<Vec<u64>>(),
517+
)
518+
}),
519+
}
520+
}
521+
454522
pub fn get(&self, index: &[u64]) -> Option<TensorElement> {
455523
let mut stride: usize = 1;
456524
let mut offset: usize = 0;
@@ -1164,3 +1232,119 @@ fn test_arrow() {
11641232
let tensors_out: Vec<Tensor> = TryIntoCollection::try_into_collection(array).unwrap();
11651233
assert_eq!(tensors_in, tensors_out);
11661234
}
1235+
1236+
#[test]
1237+
fn test_tensor_shape_utilities() {
1238+
fn generate_tensor_from_shape(sizes: &[u64]) -> Tensor {
1239+
let shape = sizes
1240+
.iter()
1241+
.map(|&size| TensorDimension { size, name: None })
1242+
.collect();
1243+
let num_elements = sizes.iter().fold(0, |acc, &size| acc * size);
1244+
let data = (0..num_elements).map(|i| i as u32).collect::<Vec<_>>();
1245+
1246+
Tensor {
1247+
tensor_id: TensorId(std::default::Default::default()),
1248+
shape,
1249+
data: TensorData::U32(data.into()),
1250+
meaning: TensorDataMeaning::Unknown,
1251+
meter: None,
1252+
}
1253+
}
1254+
1255+
// Empty tensor.
1256+
{
1257+
let tensor = generate_tensor_from_shape(&[]);
1258+
1259+
assert_eq!(tensor.image_height_width_channels(), None);
1260+
assert_eq!(tensor.shape_short(), tensor.shape());
1261+
assert!(!tensor.is_vector());
1262+
assert!(!tensor.is_shaped_like_an_image());
1263+
}
1264+
1265+
// Single dimension tensors.
1266+
for shape in [vec![4], vec![1]] {
1267+
let tensor = generate_tensor_from_shape(&shape);
1268+
1269+
assert_eq!(tensor.image_height_width_channels(), None);
1270+
assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
1271+
assert!(tensor.is_vector());
1272+
assert!(!tensor.is_shaped_like_an_image());
1273+
}
1274+
1275+
// Single element, but it might be interpreted as a 1x1 grey image!
1276+
for shape in [
1277+
vec![1, 1],
1278+
vec![1, 1, 1],
1279+
vec![1, 1, 1, 1],
1280+
vec![1, 1, 1, 1, 1],
1281+
] {
1282+
let tensor = generate_tensor_from_shape(&shape);
1283+
1284+
assert_eq!(tensor.image_height_width_channels(), Some([1, 1, 1]));
1285+
assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
1286+
assert!(tensor.is_vector());
1287+
assert!(tensor.is_shaped_like_an_image());
1288+
}
1289+
// Color/Grey 2x4 images
1290+
for shape in [
1291+
vec![4, 2],
1292+
vec![4, 2, 1],
1293+
vec![4, 2, 1, 1],
1294+
vec![4, 2, 3],
1295+
vec![4, 2, 3, 1, 1],
1296+
vec![4, 2, 4],
1297+
vec![4, 2, 4, 1, 1, 1, 1],
1298+
] {
1299+
let tensor = generate_tensor_from_shape(&shape);
1300+
let channels = shape.get(2).cloned().unwrap_or(1);
1301+
1302+
assert_eq!(tensor.image_height_width_channels(), Some([4, 2, channels]));
1303+
assert_eq!(
1304+
tensor.shape_short(),
1305+
&tensor.shape()[0..(2 + (channels != 1) as usize)]
1306+
);
1307+
assert!(!tensor.is_vector());
1308+
assert!(tensor.is_shaped_like_an_image());
1309+
}
1310+
1311+
// Grey 1x4 images
1312+
for shape in [
1313+
vec![4, 1],
1314+
vec![4, 1, 1],
1315+
vec![4, 1, 1, 1],
1316+
vec![4, 1, 1, 1, 1],
1317+
] {
1318+
let tensor = generate_tensor_from_shape(&shape);
1319+
1320+
assert_eq!(tensor.image_height_width_channels(), Some([4, 1, 1]));
1321+
assert_eq!(tensor.shape_short(), &tensor.shape()[0..1]);
1322+
assert!(tensor.is_vector());
1323+
assert!(tensor.is_shaped_like_an_image());
1324+
}
1325+
1326+
// Grey 4x1 images
1327+
for shape in [
1328+
vec![1, 4],
1329+
vec![1, 4, 1],
1330+
vec![1, 4, 1, 1],
1331+
vec![1, 4, 1, 1, 1],
1332+
] {
1333+
let tensor = generate_tensor_from_shape(&shape);
1334+
1335+
assert_eq!(tensor.image_height_width_channels(), Some([1, 4, 1]));
1336+
assert_eq!(tensor.shape_short(), &tensor.shape()[0..2]);
1337+
assert!(tensor.is_vector());
1338+
assert!(tensor.is_shaped_like_an_image());
1339+
}
1340+
1341+
// Non images & non vectors without trailing dimensions
1342+
for shape in [vec![4, 2, 5], vec![1, 1, 1, 2, 4]] {
1343+
let tensor = generate_tensor_from_shape(&shape);
1344+
1345+
assert_eq!(tensor.image_height_width_channels(), None);
1346+
assert_eq!(tensor.shape_short(), tensor.shape());
1347+
assert!(!tensor.is_vector());
1348+
assert!(!tensor.is_shaped_like_an_image());
1349+
}
1350+
}

crates/re_viewer/src/ui/space_view_heuristics.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,7 @@ fn default_created_space_views_from_candidates(
188188
&[],
189189
) {
190190
for tensor in entity_view.iter_primary_flattened() {
191-
if tensor.is_shaped_like_an_image() {
192-
debug_assert!(matches!(tensor.shape.len(), 2 | 3));
193-
191+
if let Some([height, width, _]) = tensor.image_height_width_channels() {
194192
if query_latest_single::<re_log_types::DrawOrder>(
195193
entity_db,
196194
entity_path,
@@ -205,9 +203,8 @@ fn default_created_space_views_from_candidates(
205203
.push(entity_path.clone());
206204
} else {
207205
// Otherwise, distinguish buckets by image size.
208-
let dim = (tensor.shape[0].size, tensor.shape[1].size);
209206
images_by_bucket
210-
.entry(ImageBucketing::BySize(dim))
207+
.entry(ImageBucketing::BySize((height, width)))
211208
.or_default()
212209
.push(entity_path.clone());
213210
}

crates/re_viewer_context/src/gpu_bridge/tensor_to_gpu.rs

+6-19
Original file line numberDiff line numberDiff line change
@@ -471,27 +471,14 @@ fn pad_and_narrow_and_cast<T: Copy + Pod>(
471471
fn height_width_depth(tensor: &Tensor) -> anyhow::Result<[u32; 3]> {
472472
use anyhow::Context as _;
473473

474-
let shape = &tensor.shape();
475-
476-
anyhow::ensure!(
477-
shape.len() == 2 || shape.len() == 3,
478-
"Expected a 2D or 3D tensor, got {shape:?}",
479-
);
474+
let Some([height, width, channel]) = tensor.image_height_width_channels() else {
475+
anyhow::bail!("Tensor is not an image");
476+
};
480477

481478
let [height, width] = [
482-
u32::try_from(shape[0].size).context("tensor too large")?,
483-
u32::try_from(shape[1].size).context("tensor too large")?,
479+
u32::try_from(height).context("Image height is too large")?,
480+
u32::try_from(width).context("Image width is too large")?,
484481
];
485-
let depth = if shape.len() == 2 { 1 } else { shape[2].size };
486-
487-
anyhow::ensure!(
488-
depth == 1 || depth == 3 || depth == 4,
489-
"Expected depth of 1,3,4 (gray, RGB, RGBA), found {depth:?}. Tensor shape: {shape:?}"
490-
);
491-
debug_assert!(
492-
tensor.is_shaped_like_an_image(),
493-
"We should make the same checks above, but with actual error messages"
494-
);
495482

496-
Ok([height, width, depth as u32])
483+
Ok([height, width, channel as u32])
497484
}

examples/rust/api_demo/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ fn demo_rects(rec_stream: &RecordingStream) -> anyhow::Result<()> {
251251
use ndarray_rand::{rand_distr::Uniform, RandomExt as _};
252252

253253
// Add an image
254-
let img = Array::<u8, _>::from_elem((1024, 1024, 3).f(), 128);
254+
let img = Array::<u8, _>::from_elem((1024, 1024, 3, 1).f(), 128);
255255
MsgSender::new("rects_demo/img")
256256
.with_timepoint(sim_time(1 as _))
257257
.with_component(&[Tensor::try_from(img.as_standard_layout().view())?])?

0 commit comments

Comments
 (0)