From 3984cffa3e8660dfeba232d5f3c2d84556790b05 Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Fri, 20 Sep 2024 19:26:35 -0700 Subject: [PATCH] Introduce Pad Op before tf.Split and Slice op after tf.concat ops to handle not-divisible sharding for SPMD. This is an alternate approach to using XLA ND Split/Concat ops. tf.Split and tf.Concat ops operate on a single dimension at a time. So the padding and slice ops are introduced accordingly. PiperOrigin-RevId: 677052836 --- .../mlir/tensorflow/tests/tpu_rewrite.mlir | 178 +++++++++++++++- .../tensorflow/utils/xla_sharding_util.cc | 199 +++++++++++++----- .../utils/xla_sharding_util_test.cc | 25 ++- 3 files changed, 344 insertions(+), 58 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 07e3eea2d4ea69..e7bd2191b344f5 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1891,19 +1891,19 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // Tests tile sharding of inputs with number of splits that does not evenly divide -// the input results in an error. +// the input results in an error, when shapes are not fully known. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { - func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { - %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + func.func @uneven_input_sharding_disallowed(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect input sharding configuration received. 
1-th dimension of the input must be evenly divisible by 4}} - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> } - func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { - %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + func.func @tpu0_func(%arg0: tensor, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor) -> (tensor<*xi32>, tensor<*xi1>) %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> func.return %4, %3 : tensor<*xi32>, tensor<*xi1> @@ -2839,3 +2839,169 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func.return %4, %3 : tensor<*xi32>, tensor<*xi1> } } + +// ----- + +// Tests that outputs are correctly merged and fed from TPU computation for +// tiled output sharding with padding for concat ops. 
+ +// The following OpSharding is used for TPU computation outputs in below test: +// Proto debug string: +// output 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// output 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_output + func.func @parallel_execute_with_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x5xi32>, tensor<10x5xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<128x10xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + // + // CHECK: %[[CONST_CONCAT3_DIM:.*]] = "tf.Const"() + // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2) + // CHECK: %[[CONST_SLICE_BEGIN:.*]] = "tf.Const"() + // dense<0> + // tensor<2xi64>}> : () -> tensor<2xi64> + // CHECK: %[[CONST_SLICE_SIZE:.*]] = "tf.Const"() + // dense<[128, 5]> : tensor<2xi64>}> : () -> tensor<2xi64> + // CHECK: "tf.Slice"(%[[CONCAT3_OUTPUT]], %[[CONST_SLICE_BEGIN]], %[[CONST_SLICE_SIZE]]) + // : (tensor<128x6xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<128x5xi32> + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", topology = 
"\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<128x10xi32>) -> (tensor<128x5xi32>, tensor<10x5xi1>) + tf_device.return %1, %2 : tensor<128x5xi32>, tensor<10x5xi1> + } + func.return %0#0, %1#0 : tensor<128x5xi32>, tensor<10x5xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<128x10xi32>, tensor<128x10xi32>) -> (tensor<128x10xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<10x5xi1>) -> tensor<10x5xi1> + func.return %4, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } +} + +// ----- + +// Tests inputs are correctly split and fed into TPU computation for tiled input +// sharding with padding. + +// The following OpSharding is used for TPU computation inputs in the below +// test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_input + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x9xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x9xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<128x10xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<128x10xi32>) + func.func @parallel_execute_with_tiled_input(%arg0: tensor<128x9xf32>, %arg1: tensor<128x9xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x9xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<128x10xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x9xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { + // CHECK: %[[DEVICE_LAUNCH_OUT:[a-z0-9]+]] = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = 
"/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // + // CHECK: %[[PAD_SHAPE:[a-z0-9]+]] = "tf.Const"() + // CHECK: [0, 0], [0, 1] + // CHECK: : tensor<2x2xi64>}> : () -> tensor<2x2xi64> + // CHECK: %[[PAD_OUT:[a-z0-9]+]] = "tf.Pad"(%[[DEVICE_LAUNCH_OUT]], %[[PAD_SHAPE]]) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>, tensor<2x2xi64>) -> tensor<128x10xf32> + // CHECK: %[[CONST_SPLIT_DIM:.*]] = "tf.Const"() <{value = dense<1> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor + // CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[PAD_OUT]]) {ici_weight_distribution_mlir_bridge_marker = true, num_split = 2 : i32} : (tensor, tensor<128x10xf32>) -> (tensor<128x5xf32>, tensor<128x5xf32>) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> + // + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ + %identity = "tf.Identity"(%ri_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>) -> tensor<128x9xf32> + tf_device.return %identity : tensor<128x9xf32> + }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x9xf32> + %2, %3 = "tf_device.cluster_func"(%1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x9xf32>, tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + tf_device.return %2, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } + func.return %0#0, %1#0 : tensor<128x10xi32>, tensor<10x5xi1> + } + func.func @tpu0_func(%arg0: tensor<128x9xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x9xf32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<128x10xi32>, tensor<128x10xi32>) -> (tensor<128x10xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<10x5xi1>) -> tensor<10x5xi1> + func.return %4, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } +} + +// ----- + +// CHECK: "tf.Split" +// : (tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", 
"/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { + func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = 
"/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { + %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor + %1 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> + %3:2 = tf_device.replicate {n = 2 : i32} { + %6 = "tf_device.cluster_func"(%1, %2) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> + tf_device.return %6 : tensor<*xf32> + } + %4 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor + } + func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { + %0 = "tf.XlaSharding"(%arg0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %1 = "tf.MatMul"(%0, %arg1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> + return %1 : tensor<*xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index d10b908e02d3c3..8913a1812b9c99 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -44,6 +44,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -73,15 +74,93 @@ int64_t GetPadding(const int split_dim, const int num_splits, return total_padding; } +mlir::TF::SliceOp CreateSliceOp(mlir::OpBuilder* builder, + const mlir::Location& location, + mlir::Value input, + const PartialTensorShape& shape) { + mlir::SmallVector slice_start_position; + for (int i = 0; i < shape.dims(); ++i) { + slice_start_position.push_back(0); + } + mlir::SmallVector slice_size; + for (int i = 0; i < shape.dims(); ++i) { + slice_size.push_back(shape.dim_size(i)); + } + + auto start_position_type = + mlir::RankedTensorType::get(shape.dims(), builder->getIntegerType(64)); + + auto start_position_op = builder->create( + input.getLoc(), mlir::DenseIntElementsAttr::get(start_position_type, + slice_start_position)); + + auto slice_size_op = builder->create( + input.getLoc(), mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + shape.dims(), builder->getIntegerType(64)), + slice_size)); + + auto slice_result_type = + mlir::RankedTensorType::get(slice_size, getElementTypeOrSelf(input)); + + return builder->create(input.getLoc(), slice_result_type, + input, start_position_op, + slice_size_op); +} + +mlir::TF::PadOp CreatePadOp(mlir::OpBuilder* builder, + const mlir::Location& location, int64_t num_dims, + int64_t split_dim, mlir::Value src_input, + int64_t padding) { + auto input_type = mlir::cast(src_input.getType()); + llvm::SmallVector padding_values; + std::vector padded_shape; + for (int i = 0; i < num_dims; ++i) { + // 0 padding in the beginning. + padding_values.push_back(0); + if (i == split_dim) { + // pad the split dimension to make the total size of the input equal to + // the total size of the split dimension. + padding_values.push_back(padding); + padded_shape.push_back(input_type.getShape()[i] + padding); + } else { + padding_values.push_back(0); + padded_shape.push_back(input_type.getShape()[i]); + } + } + auto padding_type = + mlir::RankedTensorType::get({num_dims, 2}, builder->getIntegerType(64)); + auto paddings = mlir::DenseIntElementsAttr::get(padding_type, padding_values); + auto paddings_value = builder->create(location, paddings); + mlir::SmallVector expand_shape(padded_shape.begin(), + padded_shape.end()); + + auto expand_result_type = + mlir::RankedTensorType::get(expand_shape, input_type.getElementType()); + + return builder->create(location, expand_result_type, + src_input, paddings_value); +} + // Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways // in 'split_dimension' dimension and returns the split values. 
-mlir::LogicalResult CreateSplitOp(const int num_split,
-                                  const int split_dimension,
-                                  const mlir::Location& location,
-                                  mlir::Value src_input,
-                                  mlir::OpBuilder* builder,
-                                  mlir::TF::SplitOp* split_op,
-                                  bool is_ici_weight_dist_spmd) {
+mlir::LogicalResult CreateSplitOp(
+    const int num_split, const int split_dimension, const int64_t padding,
+    const mlir::Location& location, mlir::Value src_input,
+    mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op,
+    bool is_ici_weight_dist_spmd) {
+  if (padding > 0) {
+    int64_t num_dims =
+        mlir::cast<mlir::RankedTensorType>(src_input.getType()).getRank();
+    auto pad_op = CreatePadOp(builder, location, num_dims, split_dimension,
+                              src_input, padding);
+    if (is_ici_weight_dist_spmd) {
+      pad_op->setAttr(kICIWeightDistributionMlirBridgeMarker,
+                      builder->getBoolAttr(true));
+    }
+    src_input = pad_op.getResult();
+  }
+
   // Creates a const op to hold split dimension value.
   auto split_dim_type =
       mlir::RankedTensorType::get({}, builder->getIntegerType(32));
@@ -139,6 +218,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split,
 // Creates a tf::ConcatOp that merges `input` values in `concat_dimension`.
 mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension,
                                   const mlir::Location& location,
+                                  const int64_t padding,
                                   mlir::ArrayRef<mlir::Value> inputs,
                                   mlir::OpBuilder* builder) {
   // Creates a const op to hold concat dimension value.
@@ -265,6 +345,22 @@ mlir::LogicalResult CreateXlaSplitNDOp(const mlir::Location& location,
   return mlir::success();
 }
 
+bool IsShapeKnown(mlir::TensorType type) {
+  if (!type.hasRank()) return false;
+
+  bool shape_known = false;
+  for (int i = 0; i < type.getRank(); ++i) {
+    if (type.getShape()[i] == mlir::ShapedType::kDynamic) {
+      shape_known = false;
+      break;
+    } else {
+      shape_known = true;
+    }
+  }
+
+  return shape_known;
+}
+
 mlir::LogicalResult HandleTileShardedInputsUsingXlaSplitOps(
     const mlir::Location& location, const xla::OpSharding& input_sharding,
     const mlir::Value& original_source, mlir::OpBuilder* builder,
@@ -335,17 +431,27 @@ mlir::LogicalResult HandleTileShardedInputsUsingTfSplitOps(
     LOG(ERROR) << dimension_to_splits_map.status();
     return mlir::failure();
   }
-
+  PartialTensorShape shape;
+  const auto input_type =
+      mlir::cast<mlir::TensorType>(original_source.getType());
+  bool input_shape_known = IsShapeKnown(input_type);
+  if (input_shape_known) {
+    shape = PartialTensorShape(input_type.getShape());
+  }
   for (const auto& dimension_and_num_splits : *dimension_to_splits_map) {
     const int dimension = dimension_and_num_splits.first;
     const int num_splits = dimension_and_num_splits.second;
+    int padding = input_shape_known
+                      ? GetPadding(dimension, num_splits,
+                                   PartialTensorShape(input_type.getShape()))
+                      : 0;
     // Creates root split op.
if (split_ops_for_tiled_input.empty()) { mlir::TF::SplitOp root_split_op; - auto result = - CreateSplitOp(num_splits, dimension, location, original_source, - builder, &root_split_op, is_ici_weight_dist_spmd); + auto result = CreateSplitOp(num_splits, dimension, padding, location, + original_source, builder, &root_split_op, + is_ici_weight_dist_spmd); if (mlir::failed(result)) return mlir::failure(); split_ops_for_tiled_input.emplace_back(root_split_op); @@ -358,7 +464,7 @@ mlir::LogicalResult HandleTileShardedInputsUsingTfSplitOps( for (auto split_op : split_ops_for_tiled_input) { for (auto parent_split_output_value : split_op.getResults()) { mlir::TF::SplitOp child_split_op; - auto result = CreateSplitOp(num_splits, dimension, location, + auto result = CreateSplitOp(num_splits, dimension, padding, location, parent_split_output_value, builder, &child_split_op, is_ici_weight_dist_spmd); if (mlir::failed(result)) return mlir::failure(); @@ -827,7 +933,15 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( LOG(ERROR) << dimension_to_splits_map.status(); return mlir::failure(); } - + auto output_type = + mlir::cast(cluster_func_output.getType()); + PartialTensorShape shape; + bool output_shape_known = IsShapeKnown(output_type); + if (output_shape_known) { + shape = PartialTensorShape(output_type.getShape()); + } + bool has_paddings = false; + std::vector paddings; for (auto it = dimension_to_splits_map->rbegin(); it != dimension_to_splits_map->rend(); ++it) { int concat_dimension = it->first; @@ -837,12 +951,21 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( new_outputs.reserve(num_splits); for (int i = 0, end = outputs_to_merge.size(); i < end; i = i + num_splits) { + int64_t padding; + if (output_shape_known) { + padding = GetPadding(concat_dimension, num_splits, shape); + } else { + padding = 0; + } mlir::TF::ConcatOp concat_op = - CreateConcatOp(concat_dimension, location, + CreateConcatOp(concat_dimension, location, padding, llvm::ArrayRef{ outputs_to_merge.begin() + i, outputs_to_merge.begin() + i + num_splits}, builder); + + paddings.push_back(padding); + has_paddings |= padding > 0; new_outputs.emplace_back(concat_op.getResult()); } @@ -850,6 +973,12 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( } assert(outputs_to_merge.size() == 1); + if (has_paddings) { + // Add slice op to remove paddings. + mlir::TF::SliceOp slice_op = + CreateSliceOp(builder, location, outputs_to_merge[0], shape); + cluster_func_output.replaceAllUsesWith(slice_op.getResult()); + } cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]); return mlir::success(); } @@ -876,26 +1005,13 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( *tiled_logical_computation_type = cluster_func_output_type; break; } - if (use_xla_nd_ops) { - if (output_shape[dimension] % output_splits == 0) { - new_output_shape[dimension] = output_shape[dimension] / output_splits; - } else { - // Input will be padded to be divisible by output_splits, thus add 1 to - // the output shape. - new_output_shape[dimension] = - (output_shape[dimension] / output_splits) + 1; - } - } else { - if (output_shape[dimension] % output_splits != 0) { - mlir::emitError( - location, - llvm::formatv("incorrect output sharding received. 
" - "{0}-th dimension of the output must be " - "evenly divisible by {1}, got dimension " - "shape {2}", - dimension, output_splits, output_shape[dimension])); - } + if (output_shape[dimension] % output_splits == 0) { new_output_shape[dimension] = output_shape[dimension] / output_splits; + } else { + // Input will be padded to be divisible by output_splits, thus add 1 to + // the output shape. + new_output_shape[dimension] = + (output_shape[dimension] / output_splits) + 1; } } @@ -904,23 +1020,6 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( return mlir::success(); } - -bool IsShapeKnown(mlir::TensorType type) { - if (!type.hasRank()) return false; - - bool shape_known = false; - for (int i = 0; i < type.getRank(); ++i) { - if (type.getShape()[i] == mlir::ShapedType::kDynamic) { - shape_known = false; - break; - } else { - shape_known = true; - } - } - - return shape_known; -} - } // namespace bool AreInputOutputShapesStaticallyKnownForSplitSharding( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc index 84d5697c9a6c2b..a168ad9984041e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc @@ -139,7 +139,6 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { - // expected-error@+1 {{incorrect input sharding configuration received. 
1-th dimension of the input must be evenly divisible by 4}} %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } @@ -165,6 +164,7 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { int num_cores_per_replica = 4; mlir::OpBuilder builder(&context); bool use_xla_nd_ops = true; + llvm::SmallVector, 4> input_list; auto result = tensorflow::ExtractInputsForLogicalDevices( num_cores_per_replica, cluster_func_op, &builder, use_xla_nd_ops, @@ -194,9 +194,30 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { // will appropriately add the values to the block. op->destroy(); + input_list.clear(); + // Expect error when use_xla_nd_ops is false. result = tensorflow::ExtractInputsForLogicalDevices( num_cores_per_replica, cluster_func_op, &builder, false, &input_list); - ASSERT_TRUE(failed(result)); + ASSERT_TRUE(succeeded(result)); + auto* split_op = input_list.front().front().getDefiningOp(); + ASSERT_TRUE(mlir::isa(split_op)); + + llvm::SmallVector split_inputs(split_op->getOperands()); + // Constant op for the split dimension + auto* const_op = split_inputs[0].getDefiningOp(); + ASSERT_TRUE(mlir::isa(const_op)); + // Pad op for the padding value to make it divisible by num_splits. + auto* pad_op = split_inputs[1].getDefiningOp(); + ASSERT_TRUE(mlir::isa(pad_op)); + llvm::SmallVector pad_inputs(pad_op->getOperands()); + auto* const_pad_value = pad_inputs[1].getDefiningOp(); + ASSERT_TRUE(mlir::isa(const_pad_value)); + // Destroy the ops to avoid error during block deletion (Same as above): + // use_empty() && "Cannot destroy a value that still has uses!" + split_op->destroy(); + const_op->destroy(); + pad_op->destroy(); + const_pad_value->destroy(); } } // namespace
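
Note (not part of the patch): every pad and slice size in the tests above follows the same ceil-division rule that GetPadding implements. The standalone sketch below is only an illustration of that arithmetic under that assumption; ComputeSplitPadding is a hypothetical name for this example, not a helper added by the patch.

#include <cstdint>
#include <iostream>

// Amount of padding appended to `dim_size` so that it becomes evenly
// divisible by `num_splits` (ceil-division of the per-shard size).
int64_t ComputeSplitPadding(int64_t dim_size, int64_t num_splits) {
  const int64_t shard_size = (dim_size + num_splits - 1) / num_splits;
  return shard_size * num_splits - dim_size;
}

int main() {
  // Input test case: a 128x9 tensor split 2 ways along dim 1 is padded by 1
  // column to 128x10, producing two 128x5 shards for the two cores.
  std::cout << ComputeSplitPadding(9, 2) << "\n";  // prints 1
  // Output test case: two 128x3 shards concatenate to 128x6, and the trailing
  // padded column is then removed by tf.Slice to give the expected 128x5.
  std::cout << ComputeSplitPadding(5, 2) << "\n";  // prints 1
  return 0;
}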