# Copyright 2024 The swirl_lm Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 15 | +"""A library for communicating information across replicas. |
| 16 | +
|
| 17 | +In an example of 4 replicas, with each replica has data with different sizes as: |
| 18 | + replica 0: data = tf.constant([]) |
| 19 | + replica 1: data = tf.constant([1]) |
| 20 | + replica 2: data = tf.constant([2, 2]) |
| 21 | + replica 3: data = tf.constant([3, 3, 3]) |
| 22 | +If data is shared in an order of 0 -> 1 -> 2 -> 3, the corresponding |
| 23 | +`source_dest_pairs` is [[0, 1], [1, 2], [2, 3]]. With a buffer size `n_max = 3`, |
| 24 | +calling `send_recv(data, source_dest_pairs, n_max)` provides the following: |
| 25 | + replica 0: tf.constant([0, 0, 0]) |
| 26 | + replica 1: tf.constant([]) |
| 27 | + replica 2: tf.constant([1]) |
| 28 | + replica 3: tf.constant([2, 2]). |
| 29 | +
|
| 30 | +Note that in the example above, the `source_dest_pairs` can be obtained by |
| 31 | +calling `source_dest_pairs_along_dim(np.array([[[0]], [[1]], [[2]], [[3]]]), 0, |
| 32 | +True, False)`, |
| 33 | +or `source_dest_pairs_along_dim(np.array([[[0]], [[1]], [[2]], [[3]]]) |
| 34 | +*parse_dim('+x'))`. |
| 35 | +""" |

import re

import numpy as np
import tensorflow as tf


def parse_dim(dim_info: str) -> tuple[int, bool, bool]:
  """Parses a dimension string into a tuple (dim, forward, periodic).

  Args:
    dim_info: A string matching '[-+][xyz]p?$'. The first character is '-' or
      '+', indicating the negative or positive direction, respectively. The
      second character is one of 'x', 'y', and 'z', corresponding to dimension
      0, 1, and 2, respectively. The optional last character is 'p', which, if
      present, indicates that the dimension is periodic.

  Returns:
    A 3-element tuple, with the first element being the dimension, the second
    indicating whether communication is along the positive direction, and the
    third indicating whether the dimension is periodic.

  Raises:
    ValueError: If `dim_info` does not match '[-+][xyz]p?$'.
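
  Example:
    # '-yp' selects dimension 1 (y), the backward direction, and a periodic
    # topology.
    dim, forward, periodic = parse_dim('-yp')  # (1, False, True)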
  """
  m = re.fullmatch(r'([-+])([xyz])(p?)', dim_info)
  if m is None:
    raise ValueError(
        f'{dim_info} does not conform with the string structure for dimension'
        ' info ("[-+][xyz]p?$").'
    )

  dim = 'xyz'.index(m.group(2))
  forward = m.group(1) == '+'
  periodic = m.group(3) == 'p'

  return dim, forward, periodic


def source_dest_pairs_along_dim(
    replicas: np.ndarray, dim: int, forward: bool, periodic: bool
) -> np.ndarray:
  """Generates a 2-D array of source-target pairs along `dim` in the topology.

  Args:
    replicas: A 3-D numpy array representing the topology of the partitions.
    dim: The dimension of communication. Must be one of 0, 1, and 2.
    forward: If `True`, data is sent from replicas with lower indices to those
      with higher indices along `dim`; if `False`, communication is performed
      in the opposite direction, i.e. from higher indices to lower indices.
    periodic: Whether the topology is periodic. When the resulting
      `source_dest_pairs` is used with `send_recv`, if `periodic` is `True`,
      data from the last replica along `dim` is sent to the first replica;
      otherwise the first replica receives zeros from the collective permute,
      which `send_recv` trims to an empty tensor. "First" and "last" here
      follow the direction specified by `forward`.

  Returns:
    A 2-D array of size `[num_pairs, 2]`, with the columns being the
    `replica_id` of the senders and the receivers, respectively.
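
  Example:
    # A 4x1x1 topology with replicas 0..3 laid out along dimension 0.
    replicas = np.array([[[0]], [[1]], [[2]], [[3]]])
    source_dest_pairs_along_dim(replicas, 0, True, False)
    # -> [[0, 1], [1, 2], [2, 3]]
    source_dest_pairs_along_dim(replicas, 0, True, True)
    # -> [[0, 1], [1, 2], [2, 3], [3, 0]]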
  """
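  # Pair each replica with its neighbor one step along `dim`: rolling by -1
  # (forward) aligns every sender with the next replica, and rolling by 1
  # (backward) with the previous one. For non-periodic topologies, the slice
  # drops the pair that wraps around the boundary.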
  rolled = np.roll(replicas, -1 if forward else 1, axis=dim)
  trim = slice(
      None if periodic or forward else 1,
      None if periodic or not forward else -1,
  )
  stacked = np.moveaxis(np.stack([replicas, rolled]), dim + 1, 1)[:, trim]
  return np.reshape(stacked, (2, -1)).T


def send_recv(
    data: tf.Tensor, source_dest_pairs: np.ndarray, n_max: int
) -> tf.Tensor:
  """Exchanges N-D `tf.Tensor`s across a list of (sender, receiver) pairs.

  Args:
    data: The N-D tensor to be sent to a different replica. Dimension 0 of this
      tensor can have different sizes across replicas.
    source_dest_pairs: A 2-D numpy array of shape `[num_pairs, 2]`, with the
      first column being the senders' `replica_id` and the second being the
      receivers' `replica_id`.
    n_max: The buffer size for the communication. It has to be greater than or
      equal to the maximum of `data.shape[0]` across all replicas; otherwise a
      runtime error will occur while padding the buffer for communication.

  Returns:
    An N-D tensor received from the sender replica specified in
    `source_dest_pairs`. Replicas that are not a destination of any pair
    receive an empty tensor.
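
  Example:
    # With 4 replicas holding [], [1], [2, 2], and [3, 3, 3], respectively, and
    # source_dest_pairs = [[0, 1], [1, 2], [2, 3]], each replica receives its
    # neighbor's data: [], [], [1], and [2, 2], respectively.
    received = send_recv(data, np.array([[0, 1], [1, 2], [2, 3]]), n_max=3)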
  """
  # `CollectivePermute` only supports transferring data of the same shape
  # across all replicas, so the input is padded with zeros to `n_max` rows
  # before the exchange.
  static_shape = data.get_shape()
  u = tf.scatter_nd(
      tf.range(tf.shape(data)[0])[:, tf.newaxis],
      data,
      (n_max, *static_shape[1:]),
  )

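  # Exchange both the padded buffer and the actual row count, so the receiver
  # can trim the padding off.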
  n_received = tf.raw_ops.CollectivePermute(
      input=tf.shape(data)[0], source_target_pairs=source_dest_pairs
  )
  w = tf.raw_ops.CollectivePermute(
      input=u, source_target_pairs=source_dest_pairs
  )
  # Here we trim the padded data back to its original size.
  return tf.gather_nd(w, tf.where(tf.range(n_max) < n_received))
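

# A minimal sketch of how these functions compose (hypothetical setup: a 4x1x1
# TPU topology, with the call running inside a replicated context such as
# `tf.distribute.TPUStrategy.run`, which `CollectivePermute` requires):
#
#   replicas = np.array([[[0]], [[1]], [[2]], [[3]]])
#   source_dest_pairs = source_dest_pairs_along_dim(
#       replicas, *parse_dim('+xp')
#   )
#
#   def step(data: tf.Tensor) -> tf.Tensor:
#     return send_recv(data, source_dest_pairs, n_max=3)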