@@ -407,50 +407,118 @@ impl Tensor {
407
407
self . shape . as_slice ( )
408
408
}
409
409
410
+ /// Returns the shape of the tensor with all trailing dimensions of size 1 ignored.
411
+ ///
412
+ /// If all dimension sizes are one, this returns only the first dimension.
413
+ #[ inline]
414
+ pub fn shape_short ( & self ) -> & [ TensorDimension ] {
415
+ if self . shape . is_empty ( ) {
416
+ & self . shape
417
+ } else {
418
+ self . shape
419
+ . iter ( )
420
+ . enumerate ( )
421
+ . rev ( )
422
+ . find ( |( _, dim) | dim. size != 1 )
423
+ . map_or ( & self . shape [ 0 ..1 ] , |( i, _) | & self . shape [ ..( i + 1 ) ] )
424
+ }
425
+ }
426
+
410
427
#[ inline]
411
428
pub fn num_dim ( & self ) -> usize {
412
429
self . shape . len ( )
413
430
}
414
431
415
- /// If this tensor is shaped as an image, return the height, width, and channels/depth of it.
432
+ /// If the tensor can be interpreted as an image, return the height, width, and channels/depth of it.
416
433
pub fn image_height_width_channels ( & self ) -> Option < [ u64 ; 3 ] > {
417
- if self . shape . len ( ) == 2 {
418
- Some ( [ self . shape [ 0 ] . size , self . shape [ 1 ] . size , 1 ] )
419
- } else if self . shape . len ( ) == 3 {
420
- let channels = self . shape [ 2 ] . size ;
421
- // gray, rgb, rgba
422
- if matches ! ( channels, 1 | 3 | 4 ) {
423
- Some ( [ self . shape [ 0 ] . size , self . shape [ 1 ] . size , channels] )
424
- } else {
425
- None
434
+ let shape_short = self . shape_short ( ) ;
435
+
436
+ match shape_short. len ( ) {
437
+ 1 => {
438
+ // Special case: Nx1(x1x1x...) tensors are treated as Nx1 grey images.
439
+ if self . shape . len ( ) >= 2 {
440
+ Some ( [ shape_short[ 0 ] . size , 1 , 1 ] )
441
+ } else {
442
+ None
443
+ }
426
444
}
427
- } else {
428
- None
445
+ 2 => Some ( [ shape_short[ 0 ] . size , shape_short[ 1 ] . size , 1 ] ) ,
446
+ 3 => {
447
+ let channels = shape_short[ 2 ] . size ;
448
+ if matches ! ( channels, 3 | 4 ) {
449
+ // rgb, rgba
450
+ Some ( [ shape_short[ 0 ] . size , shape_short[ 1 ] . size , channels] )
451
+ } else {
452
+ None
453
+ }
454
+ }
455
+ _ => None ,
429
456
}
430
457
}
431
458
459
+ /// Returns true if the tensor can be interpreted as an image.
432
460
pub fn is_shaped_like_an_image ( & self ) -> bool {
433
- self . num_dim ( ) == 2
434
- || self . num_dim ( ) == 3 && {
435
- matches ! (
436
- self . shape. last( ) . unwrap( ) . size,
437
- // gray, rgb, rgba
438
- 1 | 3 | 4
439
- )
440
- }
461
+ self . image_height_width_channels ( ) . is_some ( )
441
462
}
442
463
464
+ /// Returns true if either all dimensions have size 1 or only a single dimension has a size larger than 1.
465
+ ///
466
+ /// Empty tensors return false.
443
467
#[ inline]
444
468
pub fn is_vector ( & self ) -> bool {
445
- let shape = & self . shape ;
446
- shape. len ( ) == 1 || { shape. len ( ) == 2 && ( shape[ 0 ] . size == 1 || shape[ 1 ] . size == 1 ) }
469
+ if self . shape . is_empty ( ) {
470
+ false
471
+ } else {
472
+ self . shape . iter ( ) . filter ( |dim| dim. size > 1 ) . count ( ) <= 1
473
+ }
447
474
}
448
475
449
476
#[ inline]
450
477
pub fn meaning ( & self ) -> TensorDataMeaning {
451
478
self . meaning
452
479
}
453
480
481
+ /// Query with x, y, channel indices.
482
+ ///
483
+ /// Allows to query values for any image like tensor even if it has more or less dimensions than 3.
484
+ /// (useful for sampling e.g. `N x M x C x 1` tensor which is a valid image)
485
+ #[ inline]
486
+ pub fn get_with_image_coords ( & self , x : u64 , y : u64 , channel : u64 ) -> Option < TensorElement > {
487
+ match self . shape . len ( ) {
488
+ 1 => {
489
+ if y == 0 && channel == 0 {
490
+ self . get ( & [ x] )
491
+ } else {
492
+ None
493
+ }
494
+ }
495
+ 2 => {
496
+ if channel == 0 {
497
+ self . get ( & [ y, x] )
498
+ } else {
499
+ None
500
+ }
501
+ }
502
+ 3 => self . get ( & [ y, x, channel] ) ,
503
+ 4 => {
504
+ // Optimization for common case, next case handles this too.
505
+ if self . shape [ 3 ] . size == 1 {
506
+ self . get ( & [ y, x, channel, 0 ] )
507
+ } else {
508
+ None
509
+ }
510
+ }
511
+ dim => self . image_height_width_channels ( ) . and_then ( |_| {
512
+ self . get (
513
+ & [ x, y, channel]
514
+ . into_iter ( )
515
+ . chain ( std:: iter:: repeat ( 0 ) . take ( dim - 3 ) )
516
+ . collect :: < Vec < u64 > > ( ) ,
517
+ )
518
+ } ) ,
519
+ }
520
+ }
521
+
454
522
pub fn get ( & self , index : & [ u64 ] ) -> Option < TensorElement > {
455
523
let mut stride: usize = 1 ;
456
524
let mut offset: usize = 0 ;
@@ -1164,3 +1232,119 @@ fn test_arrow() {
1164
1232
let tensors_out: Vec < Tensor > = TryIntoCollection :: try_into_collection ( array) . unwrap ( ) ;
1165
1233
assert_eq ! ( tensors_in, tensors_out) ;
1166
1234
}
1235
+
1236
+ #[ test]
1237
+ fn test_tensor_shape_utilities ( ) {
1238
+ fn generate_tensor_from_shape ( sizes : & [ u64 ] ) -> Tensor {
1239
+ let shape = sizes
1240
+ . iter ( )
1241
+ . map ( |& size| TensorDimension { size, name : None } )
1242
+ . collect ( ) ;
1243
+ let num_elements = sizes. iter ( ) . fold ( 0 , |acc, & size| acc * size) ;
1244
+ let data = ( 0 ..num_elements) . map ( |i| i as u32 ) . collect :: < Vec < _ > > ( ) ;
1245
+
1246
+ Tensor {
1247
+ tensor_id : TensorId ( std:: default:: Default :: default ( ) ) ,
1248
+ shape,
1249
+ data : TensorData :: U32 ( data. into ( ) ) ,
1250
+ meaning : TensorDataMeaning :: Unknown ,
1251
+ meter : None ,
1252
+ }
1253
+ }
1254
+
1255
+ // Empty tensor.
1256
+ {
1257
+ let tensor = generate_tensor_from_shape ( & [ ] ) ;
1258
+
1259
+ assert_eq ! ( tensor. image_height_width_channels( ) , None ) ;
1260
+ assert_eq ! ( tensor. shape_short( ) , tensor. shape( ) ) ;
1261
+ assert ! ( !tensor. is_vector( ) ) ;
1262
+ assert ! ( !tensor. is_shaped_like_an_image( ) ) ;
1263
+ }
1264
+
1265
+ // Single dimension tensors.
1266
+ for shape in [ vec ! [ 4 ] , vec ! [ 1 ] ] {
1267
+ let tensor = generate_tensor_from_shape ( & shape) ;
1268
+
1269
+ assert_eq ! ( tensor. image_height_width_channels( ) , None ) ;
1270
+ assert_eq ! ( tensor. shape_short( ) , & tensor. shape( ) [ 0 ..1 ] ) ;
1271
+ assert ! ( tensor. is_vector( ) ) ;
1272
+ assert ! ( !tensor. is_shaped_like_an_image( ) ) ;
1273
+ }
1274
+
1275
+ // Single element, but it might be interpreted as a 1x1 grey image!
1276
+ for shape in [
1277
+ vec ! [ 1 , 1 ] ,
1278
+ vec ! [ 1 , 1 , 1 ] ,
1279
+ vec ! [ 1 , 1 , 1 , 1 ] ,
1280
+ vec ! [ 1 , 1 , 1 , 1 , 1 ] ,
1281
+ ] {
1282
+ let tensor = generate_tensor_from_shape ( & shape) ;
1283
+
1284
+ assert_eq ! ( tensor. image_height_width_channels( ) , Some ( [ 1 , 1 , 1 ] ) ) ;
1285
+ assert_eq ! ( tensor. shape_short( ) , & tensor. shape( ) [ 0 ..1 ] ) ;
1286
+ assert ! ( tensor. is_vector( ) ) ;
1287
+ assert ! ( tensor. is_shaped_like_an_image( ) ) ;
1288
+ }
1289
+ // Color/Grey 2x4 images
1290
+ for shape in [
1291
+ vec ! [ 4 , 2 ] ,
1292
+ vec ! [ 4 , 2 , 1 ] ,
1293
+ vec ! [ 4 , 2 , 1 , 1 ] ,
1294
+ vec ! [ 4 , 2 , 3 ] ,
1295
+ vec ! [ 4 , 2 , 3 , 1 , 1 ] ,
1296
+ vec ! [ 4 , 2 , 4 ] ,
1297
+ vec ! [ 4 , 2 , 4 , 1 , 1 , 1 , 1 ] ,
1298
+ ] {
1299
+ let tensor = generate_tensor_from_shape ( & shape) ;
1300
+ let channels = shape. get ( 2 ) . cloned ( ) . unwrap_or ( 1 ) ;
1301
+
1302
+ assert_eq ! ( tensor. image_height_width_channels( ) , Some ( [ 4 , 2 , channels] ) ) ;
1303
+ assert_eq ! (
1304
+ tensor. shape_short( ) ,
1305
+ & tensor. shape( ) [ 0 ..( 2 + ( channels != 1 ) as usize ) ]
1306
+ ) ;
1307
+ assert ! ( !tensor. is_vector( ) ) ;
1308
+ assert ! ( tensor. is_shaped_like_an_image( ) ) ;
1309
+ }
1310
+
1311
+ // Grey 1x4 images
1312
+ for shape in [
1313
+ vec ! [ 4 , 1 ] ,
1314
+ vec ! [ 4 , 1 , 1 ] ,
1315
+ vec ! [ 4 , 1 , 1 , 1 ] ,
1316
+ vec ! [ 4 , 1 , 1 , 1 , 1 ] ,
1317
+ ] {
1318
+ let tensor = generate_tensor_from_shape ( & shape) ;
1319
+
1320
+ assert_eq ! ( tensor. image_height_width_channels( ) , Some ( [ 4 , 1 , 1 ] ) ) ;
1321
+ assert_eq ! ( tensor. shape_short( ) , & tensor. shape( ) [ 0 ..1 ] ) ;
1322
+ assert ! ( tensor. is_vector( ) ) ;
1323
+ assert ! ( tensor. is_shaped_like_an_image( ) ) ;
1324
+ }
1325
+
1326
+ // Grey 4x1 images
1327
+ for shape in [
1328
+ vec ! [ 1 , 4 ] ,
1329
+ vec ! [ 1 , 4 , 1 ] ,
1330
+ vec ! [ 1 , 4 , 1 , 1 ] ,
1331
+ vec ! [ 1 , 4 , 1 , 1 , 1 ] ,
1332
+ ] {
1333
+ let tensor = generate_tensor_from_shape ( & shape) ;
1334
+
1335
+ assert_eq ! ( tensor. image_height_width_channels( ) , Some ( [ 1 , 4 , 1 ] ) ) ;
1336
+ assert_eq ! ( tensor. shape_short( ) , & tensor. shape( ) [ 0 ..2 ] ) ;
1337
+ assert ! ( tensor. is_vector( ) ) ;
1338
+ assert ! ( tensor. is_shaped_like_an_image( ) ) ;
1339
+ }
1340
+
1341
+ // Non images & non vectors without trailing dimensions
1342
+ for shape in [ vec ! [ 4 , 2 , 5 ] , vec ! [ 1 , 1 , 1 , 2 , 4 ] ] {
1343
+ let tensor = generate_tensor_from_shape ( & shape) ;
1344
+
1345
+ assert_eq ! ( tensor. image_height_width_channels( ) , None ) ;
1346
+ assert_eq ! ( tensor. shape_short( ) , tensor. shape( ) ) ;
1347
+ assert ! ( !tensor. is_vector( ) ) ;
1348
+ assert ! ( !tensor. is_shaped_like_an_image( ) ) ;
1349
+ }
1350
+ }
0 commit comments