
Commit 04ceee7

Fix distributed gather for tuples of tensors of varying sizes (#11071)
1 parent f05a8a0 commit 04ceee7
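
For context (an illustration, not part of the commit itself): the scenario this fixes can be sketched with the `DistributedTensorGatherer` API exercised in the new test below; the sizes and indices mirror that test.

import numpy as np
from transformers.trainer_pt_utils import DistributedTensorGatherer

# 21 samples gathered across 4 processes, in chunks of 2, 3 and 1 samples per process.
# The tuple mixes a fixed-width tensor with one whose second dimension varies per chunk.
num_samples, world_size = 21, 4
predictions = np.random.normal(size=(num_samples, 13))
input_indices = [
    [0, 1, 6, 7, 12, 13, 18, 19],
    [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
    [5, 11, 17, 2],
]
sequence_lengths = [8, 10, 13]

gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
for indices, seq_length in zip(input_indices, sequence_lengths):
    # Before this fix, storage expansion was decided from the first leaf only
    # (via _get_first_shape), so a tuple whose varying-width tensor is not the
    # first element could not be gathered correctly.
    gatherer.add_arrays([predictions[indices], predictions[indices, :seq_length]])
result = gatherer.finalize()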

File tree

2 files changed: +52 -22 lines changed


src/transformers/trainer_pt_utils.py

Lines changed: 9 additions & 22 deletions
@@ -276,11 +276,8 @@ def nested_new_like(arrays, num_samples, padding_index=-100):
     return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
 
 
-def nested_expand_like(arrays, new_seq_length, padding_index=-100):
+def expand_like(arrays, new_seq_length, padding_index=-100):
     """ Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
-    if isinstance(arrays, (list, tuple)):
-        return type(arrays)(nested_expand_like(x, new_seq_length, padding_index=padding_index) for x in arrays)
-
     result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
     result[:, : arrays.shape[1]] = arrays
     return result
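
As a side note (an illustration, not part of the diff): `expand_like` pads the second dimension out to `new_seq_length` with `padding_index`, exactly as in the function body above. A minimal NumPy example:

import numpy as np

arrays = np.array([[1, 2, 3], [4, 5, 6]])
# Grow the second dimension from 3 to 5; new positions are filled with -100.
result = np.full_like(arrays, -100, shape=(arrays.shape[0], 5) + arrays.shape[2:])
result[:, : arrays.shape[1]] = arrays
# result == [[1, 2, 3, -100, -100],
#            [4, 5, 6, -100, -100]]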
@@ -293,13 +290,6 @@ def nested_truncate(tensors, limit):
     return tensors[:limit]
 
 
-def _get_first_shape(arrays):
-    """Return the shape of the first array found in the nested struct `arrays`."""
-    if isinstance(arrays, (list, tuple)):
-        return _get_first_shape(arrays[0])
-    return arrays.shape
-
-
 class DistributedTensorGatherer:
     """
     A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
@@ -367,21 +357,15 @@ def add_arrays(self, arrays):
         if self._storage is None:
             self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
             self._offsets = list(range(0, self.total_samples, self.process_length))
-        else:
-            storage_shape = _get_first_shape(self._storage)
-            arrays_shape = _get_first_shape(arrays)
-            if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]:
-                # If we get new arrays that are too big too fit, we expand the shape fo the storage
-                self._storage = nested_expand_like(self._storage, arrays_shape[1], padding_index=self.padding_index)
-            slice_len = self._nested_set_tensors(self._storage, arrays)
+
+        slice_len, self._storage = self._nested_set_tensors(self._storage, arrays)
         for i in range(self.world_size):
             self._offsets[i] += slice_len
 
     def _nested_set_tensors(self, storage, arrays):
         if isinstance(arrays, (list, tuple)):
-            for x, y in zip(storage, arrays):
-                slice_len = self._nested_set_tensors(x, y)
-            return slice_len
+            result = [self._nested_set_tensors(x, y) for x, y in zip(storage, arrays)]
+            return result[0][0], type(arrays)(r[1] for r in result)
         assert (
             arrays.shape[0] % self.world_size == 0
         ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
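
A quick illustration (not part of the diff) of the new return convention for nested inputs: each leaf yields a `(slice_len, storage)` pair, and `type(arrays)(...)` rebuilds the original container so lists stay lists and tuples stay tuples:

# One (slice_len, possibly-expanded storage) pair per leaf, as produced by the recursion above.
result = [(2, "storage_a"), (2, "storage_b")]
arrays = ("leaf_a", "leaf_b")  # the original nested input was a tuple
slice_len, new_storage = result[0][0], type(arrays)(r[1] for r in result)
assert slice_len == 2 and new_storage == ("storage_a", "storage_b")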
@@ -391,10 +375,13 @@ def _nested_set_tensors(self, storage, arrays):
             if len(arrays.shape) == 1:
                 storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
             else:
+                # Expand the array on the fly if needed.
+                if len(storage.shape) > 1 and storage.shape[1] < arrays.shape[1]:
+                    storage = expand_like(storage, arrays.shape[1], padding_index=self.padding_index)
                 storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
                     i * slice_len : (i + 1) * slice_len
                 ]
-        return slice_len
+        return slice_len, storage
 
     def finalize(self):
         """

tests/test_trainer_utils.py

Lines changed: 43 additions & 0 deletions
@@ -82,6 +82,49 @@ def test_distributed_tensor_gatherer(self):
         self.assertTrue(np.array_equal(result[1][0], predictions))
         self.assertTrue(np.array_equal(result[1][1], predictions))
 
+    def test_distributed_tensor_gatherer_different_shapes(self):
+        # Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
+        world_size = 4
+        num_samples = 21
+        input_indices = [
+            [0, 1, 6, 7, 12, 13, 18, 19],
+            [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
+            [5, 11, 17, 2],
+        ]
+        sequence_lengths = [8, 10, 13]
+
+        predictions = np.random.normal(size=(num_samples, 13))
+        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
+        for indices, seq_length in zip(input_indices, sequence_lengths):
+            gatherer.add_arrays(predictions[indices, :seq_length])
+        result = gatherer.finalize()
+
+        # Remove the extra samples added at the end for a round multiple of num processes.
+        actual_indices = [input_indices[0], input_indices[1][:-2], input_indices[2][:-1]]
+        for indices, seq_length in zip(actual_indices, sequence_lengths):
+            self.assertTrue(np.array_equal(result[indices, :seq_length], predictions[indices, :seq_length]))
+
+        # With nested tensors
+        predictions = np.random.normal(size=(num_samples, 13))
+        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
+        for indices, seq_length in zip(input_indices, sequence_lengths):
+            gatherer.add_arrays([predictions[indices, :seq_length], predictions[indices]])
+        result = gatherer.finalize()
+
+        for indices, seq_length in zip(actual_indices, sequence_lengths):
+            self.assertTrue(np.array_equal(result[0][indices, :seq_length], predictions[indices, :seq_length]))
+        self.assertTrue(np.array_equal(result[1], predictions))
+
+        # Check if works if varying seq_length is second
+        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
+        for indices, seq_length in zip(input_indices, sequence_lengths):
+            gatherer.add_arrays([predictions[indices], predictions[indices, :seq_length]])
+        result = gatherer.finalize()
+
+        self.assertTrue(np.array_equal(result[0], predictions))
+        for indices, seq_length in zip(actual_indices, sequence_lengths):
+            self.assertTrue(np.array_equal(result[1][indices, :seq_length], predictions[indices, :seq_length]))
+
     def test_label_smoothing(self):
         epsilon = 0.1
         num_labels = 12
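
To run just the new test locally, standard pytest keyword filtering against the file touched by this commit should work (assuming a development install of the repository):

python -m pytest tests/test_trainer_utils.py -k test_distributed_tensor_gatherer_different_shapes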

0 commit comments
