Enforce static dimensions in generation of flow.tensor.transfer (#205)
This solves the problem in iree-org/iree#18283.
The issue is that we generate casts to/from dynamic tensors that later lowering stages in IREE choke on. My assumption is that IREE should be able to digest this IR, since it has the form:

```mlir
    %2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3,11,13],f32> -> tensor<2x3x11x13xf32>
    %cast = tensor.cast %2 : tensor<2x3x11x13xf32> to tensor<?x?x?x?xf32>
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %cast, %c0 : tensor<?x?x?x?xf32>
    %c1 = arith.constant 1 : index
    %dim_0 = tensor.dim %cast, %c1 : tensor<?x?x?x?xf32>
    %c2 = arith.constant 2 : index
    %dim_1 = tensor.dim %cast, %c2 : tensor<?x?x?x?xf32>
    %c3 = arith.constant 3 : index
    %dim_2 = tensor.dim %cast, %c3 : tensor<?x?x?x?xf32>
    %3 = flow.tensor.transfer %cast : tensor<?x?x?x?xf32>{%dim, %dim_0, %dim_1, %dim_2} to #hal.device.promise<@__device_0>
    %cast_3 = tensor.cast %3 : tensor<?x?x?x?xf32> to tensor<2x3x11x13xf32>
    %4 = torch_c.from_builtin_tensor %cast_3 : tensor<2x3x11x13xf32> -> !torch.vtensor<[2,3,11,13],f32>
```
It essentially casts to a dynamic `tensor<...>` for the purpose of performing the `flow.tensor.transfer`, then casts back to a static `torch.vtensor`, so it should be fine.

With this change we get:
```mlir
    %2 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[2,3,11,13],f32> -> tensor<2x3x11x13xf32>
    %3 = flow.tensor.transfer %2 : tensor<2x3x11x13xf32> to #hal.device.promise<@__device_0>
    %4 = torch_c.from_builtin_tensor %3 : tensor<2x3x11x13xf32> -> !torch.vtensor<[2,3,11,13],f32>
```
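The effect of specializing all dimensions can be illustrated with a toy sketch (the `TensorArg` class and `mlir_tensor_type` helper here are hypothetical; the real `KernelSelection` API in iree.turbine differs in detail): pinning every dimension to its concrete value makes the emitted tensor type fully static instead of `?x?x?x?`.

```python
# Toy model of dimension specialization (hypothetical names, for
# illustration only; not the actual iree.turbine API).
class TensorArg:
    def __init__(self, shape):
        self.shape = shape                        # concrete dims, e.g. (2, 3, 11, 13)
        self.spec_dims = [None] * len(shape)      # None = dynamic ("?") in emitted IR

    def specialize_all_dims(self):
        # Pin every dimension to its concrete value so the generated
        # IR uses a fully static tensor type.
        self.spec_dims = list(self.shape)
        return self

def mlir_tensor_type(arg, elem="f32"):
    # Render the spec_dims as an MLIR tensor type string.
    dims = "x".join("?" if d is None else str(d) for d in arg.spec_dims)
    return f"tensor<{dims}x{elem}>"

arg = TensorArg((2, 3, 11, 13))
print(mlir_tensor_type(arg))                        # tensor<?x?x?x?xf32>
print(mlir_tensor_type(arg.specialize_all_dims()))  # tensor<2x3x11x13xf32>
```

With all dims specialized, there is no need for the `tensor.cast` round-trip or the per-dimension `tensor.dim` queries in the generated IR.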

Signed-off-by: Boian Petkantchin <[email protected]>
sogartar authored Oct 9, 2024
1 parent 351f2fe commit 586b9af
Showing 1 changed file with 2 additions and 1 deletion: iree/turbine/ops/iree.py
```diff
@@ -83,7 +83,8 @@ class transfer_to_logical_device(CustomOp):
     def select(self, ksel: KernelSelection):
         ksel.attr_str(0)
         ta = ksel.arg_tensor(1)
-        ksel.return_tensor(ta.t)
+        ta.specialize_all_dims()
+        ksel.return_tensor(ta.t).specialize_all_dims()

     def eager_execute(self, device_moniker, tensor):
         return tensor
```
