Skip to content

Commit

Permalink
Distinguish between list of data iterators and data iterator that is …
Browse files Browse the repository at this point in the history
…a list

Signed-off-by: Tim Moon <[email protected]>
  • Loading branch information
timmoon10 committed May 8, 2023
1 parent b40342b commit e1b455f
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@

try:
from megatron.core import parallel_state
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.pipeline_parallel.schedules import DataIteratorList, get_forward_backward_func

HAVE_MEGATRON_CORE = True

Expand Down Expand Up @@ -551,7 +551,7 @@ def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]:
"""

if not isinstance(self.model, list) or len(self.model) == 1:
return [data_iterator]
return DataIteratorList([data_iterator])

class CachingIterator:
"""Iterator wrapper that caches values"""
Expand Down Expand Up @@ -589,10 +589,10 @@ def __next__(self):
return val

# Make list of iterator wrappers
data_iterator_list = [CachingIterator(data_iterator)]
while len(data_iterator_list) < len(self.model):
data_iterator_list.append(data_iterator_list[0].make_proxy())
return data_iterator_list
iters = [CachingIterator(data_iterator)]
while len(iters) < len(self.model):
iters.append(iters[0].make_proxy())
return DataIteratorList(iters)

def get_forward_output_and_loss_func(self, validation_step=False):
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):
Expand Down

0 comments on commit e1b455f

Please sign in to comment.