@@ -276,11 +276,8 @@ def nested_new_like(arrays, num_samples, padding_index=-100):
276276 return np .full_like (arrays , padding_index , shape = (num_samples , * arrays .shape [1 :]))
277277
278278
279- def nested_expand_like (arrays , new_seq_length , padding_index = - 100 ):
279+ def expand_like (arrays , new_seq_length , padding_index = - 100 ):
280280 """ Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
281- if isinstance (arrays , (list , tuple )):
282- return type (arrays )(nested_expand_like (x , new_seq_length , padding_index = padding_index ) for x in arrays )
283-
284281 result = np .full_like (arrays , padding_index , shape = (arrays .shape [0 ], new_seq_length ) + arrays .shape [2 :])
285282 result [:, : arrays .shape [1 ]] = arrays
286283 return result
@@ -293,13 +290,6 @@ def nested_truncate(tensors, limit):
293290 return tensors [:limit ]
294291
295292
296- def _get_first_shape (arrays ):
297- """Return the shape of the first array found in the nested struct `arrays`."""
298- if isinstance (arrays , (list , tuple )):
299- return _get_first_shape (arrays [0 ])
300- return arrays .shape
301-
302-
303293class DistributedTensorGatherer :
304294 """
305295 A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
@@ -367,21 +357,15 @@ def add_arrays(self, arrays):
367357 if self ._storage is None :
368358 self ._storage = nested_new_like (arrays , self .total_samples , padding_index = self .padding_index )
369359 self ._offsets = list (range (0 , self .total_samples , self .process_length ))
370- else :
371- storage_shape = _get_first_shape (self ._storage )
372- arrays_shape = _get_first_shape (arrays )
373- if len (storage_shape ) > 1 and storage_shape [1 ] < arrays_shape [1 ]:
374- # If we get new arrays that are too big too fit, we expand the shape fo the storage
375- self ._storage = nested_expand_like (self ._storage , arrays_shape [1 ], padding_index = self .padding_index )
376- slice_len = self ._nested_set_tensors (self ._storage , arrays )
360+
361+ slice_len , self ._storage = self ._nested_set_tensors (self ._storage , arrays )
377362 for i in range (self .world_size ):
378363 self ._offsets [i ] += slice_len
379364
380365 def _nested_set_tensors (self , storage , arrays ):
381366 if isinstance (arrays , (list , tuple )):
382- for x , y in zip (storage , arrays ):
383- slice_len = self ._nested_set_tensors (x , y )
384- return slice_len
367+ result = [self ._nested_set_tensors (x , y ) for x , y in zip (storage , arrays )]
368+ return result [0 ][0 ], type (arrays )(r [1 ] for r in result )
385369 assert (
386370 arrays .shape [0 ] % self .world_size == 0
387371 ), f"Arrays passed should all have a first dimension multiple of { self .world_size } , found { arrays .shape [0 ]} ."
@@ -391,10 +375,13 @@ def _nested_set_tensors(self, storage, arrays):
391375 if len (arrays .shape ) == 1 :
392376 storage [self ._offsets [i ] : self ._offsets [i ] + slice_len ] = arrays [i * slice_len : (i + 1 ) * slice_len ]
393377 else :
378+ # Expand the array on the fly if needed.
379+ if len (storage .shape ) > 1 and storage .shape [1 ] < arrays .shape [1 ]:
380+ storage = expand_like (storage , arrays .shape [1 ], padding_index = self .padding_index )
394381 storage [self ._offsets [i ] : self ._offsets [i ] + slice_len , : arrays .shape [1 ]] = arrays [
395382 i * slice_len : (i + 1 ) * slice_len
396383 ]
397- return slice_len
384+ return slice_len , storage
398385
399386 def finalize (self ):
400387 """
0 commit comments