[TTNN] Adding support for data type workarounds and introducing Embedding workarounds (#1583)

This PR introduces a solution for handling data type workarounds for operation operands and results. To work around an input operand's data type, we insert a `ttnn.to_layout` operation between the input operand and the operation itself, casting the input to the desired data type. If a workaround changes the data type of the output result, we revert it to the previous data type by inserting a `ttnn.to_layout` operation after the operation's output, so downstream users are unaffected (see the sketch after the examples below).

Additionally, this PR adds the workarounds needed for the embedding operation to function correctly. Specifically, it changes the input to a row-major layout and casts both the weight operand and the output to bf16. Other ops will be onboarded to this type of workaround in a separate PR.

Example of IR today:

```mlir
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout3>
    %2 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
    %3 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xf32, #ttnn_layout5>
    %4 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout6>
    %5 = "ttnn.embedding"(%2, %3, %4) : (tensor<32x32xf32, #ttnn_layout4>, tensor<512x128xf32, #ttnn_layout5>, tensor<32x32x128xf32, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout6>
    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xf32, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout2>
    return %6 : tensor<32x32x128xf32, #ttnn_layout2>
  }
}
```

An example of IR with this change, where the embedding op has the bf16 workaround applied to its weight operand:

```mlir
module attributes {tt.device = #device, tt.system_desc = #system_desc} {
  func.func @forward(%arg0: tensor<32x32xf32, #ttnn_layout>, %arg1: tensor<512x128xf32, #ttnn_layout1>) -> tensor<32x32x128xf32, #ttnn_layout2> {
    %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
    %1 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>, shape = #ttnn.shape<32x32x128>}> : (!tt.device<#device>) -> tensor<32x32x128xf32, #ttnn_layout3>
    %2 = "ttnn.to_layout"(%arg0, %0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#dram, <<32x32>>, <interleaved>>}> : (tensor<32x32xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<32x32xf32, #ttnn_layout4>
    %3 = "ttnn.to_layout"(%arg1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x4>>, <interleaved>>}> : (tensor<512x128xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x128xbf16, #ttnn_layout5>
    %4 = "ttnn.to_layout"(%1, %0) <{dtype = #tt.supportedDataTypes<bf16>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<32x4>>, <interleaved>>}> : (tensor<32x32x128xf32, #ttnn_layout3>, !tt.device<#device>) -> tensor<32x32x128xbf16, #ttnn_layout6>
    %5 = "ttnn.embedding"(%2, %3, %4) : (tensor<32x32xf32, #ttnn_layout4>, tensor<512x128xbf16, #ttnn_layout5>, tensor<32x32x128xbf16, #ttnn_layout6>) -> tensor<32x32x128xbf16, #ttnn_layout6>
    %6 = "ttnn.to_layout"(%5) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<row_major>, memory_config = #ttnn.memory_config<#system_memory, <<1024x128>>>}> : (tensor<32x32x128xbf16, #ttnn_layout6>) -> tensor<32x32x128xf32, #ttnn_layout2>
    return %6 : tensor<32x32x128xf32, #ttnn_layout2>
  }
}
```

- Closes #1433
- Closes #1497
- Closes #1215
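For illustration, here is a minimal C++ sketch of the insert-and-revert mechanism described above, written against core MLIR APIs only. This is a sketch under stated assumptions, not the actual tt-mlir implementation: the helper name `applyOperandResultWorkaround` is hypothetical, and `UnrealizedConversionCastOp` stands in for building a real `ttnn.to_layout` op with the appropriate dtype, layout, and memory_config attributes.

```cpp
// Illustrative sketch only: the real pass would create a ttnn::ToLayoutOp
// (carrying dtype/layout/memory_config attributes) instead of the generic
// unrealized_conversion_cast used here as a placeholder.
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper: route operand `idx` of `op` through a cast to
// `workaroundType`, retype the result to match, then cast the result back
// so every downstream user still sees the original data type.
static void applyOperandResultWorkaround(Operation *op, unsigned idx,
                                         RankedTensorType workaroundType) {
  Location loc = op->getLoc();

  // 1. Insert a cast between the input operand and the operation itself.
  OpBuilder builder(op); // insertion point is right before `op`
  Value casted = builder
                     .create<UnrealizedConversionCastOp>(
                         loc, TypeRange{workaroundType},
                         ValueRange{op->getOperand(idx)})
                     .getResult(0);
  op->setOperand(idx, casted);

  // 2. The workaround changes the result's data type too (e.g. f32 -> bf16).
  Type originalType = op->getResult(0).getType();
  op->getResult(0).setType(workaroundType);

  // 3. Revert: cast the result back to its original data type right after
  //    the op, and reroute all other users to the reverted value.
  builder.setInsertionPointAfter(op);
  auto revert = builder.create<UnrealizedConversionCastOp>(
      loc, TypeRange{originalType}, ValueRange{op->getResult(0)});
  op->getResult(0).replaceAllUsesExcept(revert.getResult(0), revert);
}
```

The design point worth noting is the revert cast in step 3: it keeps the workaround local to the op, which is what the second IR example shows: `%5` produces bf16, and the trailing `ttnn.to_layout` restores f32 before the function returns.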
Commit 9520cbb · 1 parent 109d917 · 10 changed files with 168 additions and 29 deletions