@@ -483,18 +483,7 @@ def main(A: R.Tensor([4096], "float32")):
483483 class Expected :
484484 @R .function
485485 def main (A : R .Tensor ([4096 ], "float32" )):
486- B = R .ExternFunc (
487- "runtime.TVMArrayCreateView" ,
488- R .Callable (
489- derive_func = "tvm.relax.struct_info.infer_view_sinfo" ,
490- purity = True ,
491- ),
492- )(
493- A ,
494- R .shape ([64 , 64 ]),
495- R .dtype ("float32" ),
496- R .prim_value (0 ),
497- )
486+ B = R .memory .view (A , shape = R .shape ([64 , 64 ]), dtype = "float32" , relative_byte_offset = 0 )
498487 return B
499488
500489 After = tvm .relax .transform .LegalizeOps ()(Before )
@@ -515,18 +504,7 @@ def main(A: R.Tensor(dtype="float32")):
515504 class Expected :
516505 @R .function
517506 def main (A : R .Tensor (dtype = "float32" )):
518- B = R .ExternFunc (
519- "runtime.TVMArrayCreateView" ,
520- R .Callable (
521- derive_func = "tvm.relax.struct_info.infer_view_sinfo" ,
522- purity = True ,
523- ),
524- )(
525- A ,
526- R .shape ([64 , 64 ]),
527- R .dtype ("float32" ),
528- R .prim_value (0 ),
529- )
507+ B = R .memory .view (A , shape = R .shape ([64 , 64 ]), dtype = "float32" , relative_byte_offset = 0 )
530508 return B
531509
532510 After = tvm .relax .transform .LegalizeOps ()(Before )
@@ -545,17 +523,8 @@ def main(A: R.Tensor([4096], "float32")):
545523 class Expected :
546524 @R .function
547525 def main (A : R .Tensor ([4096 ], "float32" )):
548- B = R .ExternFunc (
549- "runtime.TVMArrayCreateView" ,
550- R .Callable (
551- derive_func = "tvm.relax.struct_info.infer_view_sinfo" ,
552- purity = True ,
553- ),
554- )(
555- A ,
556- R .shape ([4096 ]),
557- R .dtype ("int32" ),
558- R .prim_value (0 ),
526+ B = R .memory .view (
527+ A , dtype = R .dtype ("int32" ), shape = R .shape ([4096 ]), relative_byte_offset = 0
559528 )
560529 return B
561530
@@ -575,17 +544,8 @@ def main(A: R.Tensor([4096], "float32")):
575544 class Expected :
576545 @R .function
577546 def main (A : R .Tensor ([4096 ], "float32" )):
578- B = R .ExternFunc (
579- "runtime.TVMArrayCreateView" ,
580- R .Callable (
581- derive_func = "tvm.relax.struct_info.infer_view_sinfo" ,
582- purity = True ,
583- ),
584- )(
585- A ,
586- R .shape ([4096 ]),
587- R .dtype ("float32" ),
588- R .prim_value (0 ),
547+ B = R .memory .view (
548+ A , relative_byte_offset = R .prim_value (0 ), shape = R .shape ([4096 ]), dtype = "float32"
589549 )
590550 return B
591551
@@ -624,29 +584,17 @@ def main(A: R.Tensor([4096], "uint8")):
624584 class Expected :
625585 @R .function
626586 def main (A : R .Tensor ([4096 ], "uint8" )):
627- B = R .ExternFunc (
628- "runtime.TVMArrayCreateView" ,
629- R .Callable (
630- derive_func = "tvm.relax.struct_info.infer_view_sinfo" ,
631- purity = True ,
632- ),
633- )(
587+ B = R .memory .view (
634588 A ,
635- R .shape ([512 ]),
636- R .dtype ("int32" ),
637- R .prim_value (0 ),
589+ shape = R .shape ([512 ]),
590+ dtype = R .dtype ("int32" ),
591+ relative_byte_offset = R .prim_value (0 ),
638592 )
639- C = R .ExternFunc (
640- "runtime.TVMArrayCreateView" ,
641- R .Callable (
642- derive_func = "tvm.relax.struct_info.infer_view_sinfo" ,
643- purity = True ,
644- ),
645- )(
593+ C = R .memory .view (
646594 A ,
647- R .shape ([16 , 64 ]),
648- R .dtype ("float16" ),
649- R .prim_value (2048 ),
595+ shape = R .shape ([16 , 64 ]),
596+ dtype = R .dtype ("float16" ),
597+ relative_byte_offset = R .prim_value (2048 ),
650598 )
651599 return (B , C )
652600
@@ -772,5 +720,30 @@ def main(A: R.Tensor([4096], "uint8")):
772720 tvm .testing .assert_allclose (tvm_output [1 ].numpy (), np_expected [1 ])
773721
774722
723+ @tvm .testing .parametrize_targets ("llvm" , "cuda" )
724+ def test_execute_view_with_new_byte_offset_ensure_aligned (target , dev ):
725+ @I .ir_module
726+ class Module :
727+ @R .function
728+ def main (A : R .Tensor ([4096 ], "float32" )):
729+ B = R .memory .view (
730+ A ,
731+ shape = R .shape ([16 , 64 ]),
732+ relative_byte_offset = 32 * 64 * 4 ,
733+ )
734+ C = R .memory .ensure_aligned (B )
735+ return C
736+
737+ built = tvm .relax .build (Module , target = target )
738+ vm = tvm .relax .VirtualMachine (built , device = dev )
739+
740+ np_input = np .random .random ([4096 ]).astype ("float32" )
741+ tvm_input = tvm .nd .array (np_input , dev )
742+ tvm_output = vm ["main" ](tvm_input )
743+ np_expected = np_input .reshape (64 , 64 )[32 :48 , :]
744+
745+ tvm .testing .assert_allclose (tvm_output .numpy (), np_expected )
746+
747+
775748if __name__ == "__main__" :
776749 tvm .testing .main ()
0 commit comments