-
Notifications
You must be signed in to change notification settings - Fork 373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update tensor.rs to Handle Trailing Unit-Dimensions in Image Tensors #1916
Update tensor.rs to Handle Trailing Unit-Dimensions in Image Tensors #1916
Conversation
- Previously, tensor.rs only displayed tensors with shapes of [H, W, 1|3|4] as images, this update addresses the issue by modifying the `image_height_width_channels()`, `is_shaped_like_an_image()`, and `is_vector()` methods to ignore trailing unit-dimensions when checking the shape of the tensor - This change should improve the user experience by allowing a wider range of image tensor shapes to be displayed as images in tensor.rs
@@ -424,20 +435,33 @@ impl Tensor { | |||
} | |||
|
|||
pub fn is_shaped_like_an_image(&self) -> bool { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can simplify this quite a bit:
pub fn is_shaped_like_an_image(&self) -> bool { | |
pub fn is_shaped_like_an_image(&self) -> bool { | |
self.image_height_width_channels().is_some() | |
} |
} | ||
|
||
#[inline] | ||
pub fn is_vector(&self) -> bool { | ||
let shape = &self.shape; | ||
shape.len() == 1 || { shape.len() == 2 && (shape[0].size == 1 || shape[1].size == 1) } | ||
let mut last_dim = shape.last().unwrap().size; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this will panic if shape
is empty (for "scalar tensors")
let mut shape = self.shape.clone(); | ||
|
||
// Remove trailing unit dimensions | ||
while let Some(&d) = shape.last() { | ||
if d.size == 1 { | ||
shape.pop(); | ||
} else { | ||
break; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest encapsulating this in a helper function, e.g:
/// Returns the shape, but ignoring trailing dimensions of length 1.
///
/// For instance, a shape of `[640, 480, 1, 1]` would be returned as `[640, 480]`.
fn short_shape(&self) -> Vec<TensorDimension> {
…
}
this could then be used by image_height_width_channels
as well as is_vector
Thanks for the PR! It would be great if you added a simple test of this to |
@tauseefmohammed2 you're still on that? :) |
Took the task over in #2075 |
What
Previously, tensor.rs only displayed tensors with shapes of [H, W, 1|3|4] as images, this update addresses the issue by modifying the
image_height_width_channels()
,is_shaped_like_an_image()
, andis_vector()
methods to ignore trailing unit-dimensions when checking the shape of the tensorThis change should improve the user experience by allowing a wider range of image tensor shapes to be displayed as images in tensor.rs
Resolves Issue: Show tensors shaped
[H, W, 1, 1]
as images #1871Checklist