Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enforce static dimensions in generation of flow.tensor.transfer (#205)
This solves the problem in iree-org/iree#18283 The issue is that we generate cast to/from dynamic tensors that later lowering in IREE chokes on it. My assumption is that it should be able to digest this IR since it is of the form. ```mlir %2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3,11,13],f32> -> tensor<2x3x11x13xf32> %cast = tensor.cast %2 : tensor<2x3x11x13xf32> to tensor<?x?x?x?xf32> %c0 = arith.constant 0 : index %dim = tensor.dim %cast, %c0 : tensor<?x?x?x?xf32> %c1 = arith.constant 1 : index %dim_0 = tensor.dim %cast, %c1 : tensor<?x?x?x?xf32> %c2 = arith.constant 2 : index %dim_1 = tensor.dim %cast, %c2 : tensor<?x?x?x?xf32> %c3 = arith.constant 3 : index %dim_2 = tensor.dim %cast, %c3 : tensor<?x?x?x?xf32> %3 = flow.tensor.transfer %cast : tensor<?x?x?x?xf32>{%dim, %dim_0, %dim_1, %dim_2} to #hal.device.promise<@__device_0> %cast_3 = tensor.cast %3 : tensor<?x?x?x?xf32> to tensor<2x3x11x13xf32> %4 = torch_c.from_builtin_tensor %cast_3 : tensor<2x3x11x13xf32> -> !torch.vtensor<[2,3,11,13],f32> ``` It essentially casts to a dynamic `tensor<...>` for the purpose of performing `flow.tensor.transfer` and then casts back to a static `torch.vtensor`. So it should be fine. With this change we get ```mlir %2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3,11,13],f32> -> tensor<2x3x11x13xf32> %3 = flow.tensor.transfer %2 : tensor<2x3x11x13xf32> to #hal.device.promise<@__device_0> %4 = torch_c.from_builtin_tensor %3 : tensor<2x3x11x13xf32> -> !torch.vtensor<[2,3,11,13],f32> ``` Signed-off-by: Boian Petkantchin <[email protected]>
- Loading branch information