Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pmarkovicTT committed Jan 13, 2025
1 parent c8b7156 commit f846f7f
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ module attributes {tt.device = #device, tt.system_desc = #system_desc} {
func.func @backward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>, %arg2: tensor<32x32x128xf32, #ttnn_layout2>) -> tensor<32x32x128xf32, #ttnn_layout3> {
%0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
// CHECK: %[[DEVICE_OP:.*]] = "ttnn.get_device"
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout3>
%2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout4>
%3 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout5>
%4 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout5>
%1 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
%2 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout5>
%3 = "ttnn.to_layout"(%arg2, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout2>, !tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout6>
%4 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout7>
// CHECK: %[[EMPTY_OP:.*]] = "ttnn.empty"(%[[DEVICE_OP]])
// Check that the input operand is transformed into the row major layout.
// CHECK-NEXT: %[[TO_LAYOUT_INPUT:.*]] = "ttnn.to_layout"
Expand Down

0 comments on commit f846f7f

Please sign in to comment.