Skip to content

Commit

Permalink
prefetch num microbatches
Browse files Browse the repository at this point in the history
Signed-off-by: eharper <[email protected]>
  • Loading branch information
ericharper committed Aug 8, 2023
1 parent bcc6072 commit aa5b5fb
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import gc
import itertools
import os
import re
from dataclasses import fields
Expand Down Expand Up @@ -783,3 +784,20 @@ def build_model_parallel_config(self):
)

return model_parallel_config

def _prefetch(self, iterator):
    """Prefetch one microbatch group from ``iterator`` and report exhaustion.

    Used in models using dataloader_iter to prefetch the next batch before the
    fwd_bwd func is called, to avoid PP rank 2 waiting indefinitely to get
    outputs from PP 1.

    Up to ``get_num_microbatches()`` items are pulled from ``iterator``.

    Returns:
        A ``(iterator, done)`` tuple. When every item could be pulled, the
        first entry is a new iterator that yields the prefetched items
        followed by the remainder of ``iterator``, and ``done`` is ``False``.
        When ``iterator`` runs out first, the original ``iterator`` is
        returned with ``done`` set to ``True``; items pulled before it ran
        out are not handed back.
    """
    # Unique marker so that exhaustion is never confused with a real item.
    exhausted = object()
    prefetched = []
    for _ in range(get_num_microbatches()):
        item = next(iterator, exhausted)
        if item is exhausted:
            return iterator, True
        prefetched.append(item)

    # Put the prefetched items back at the front of the stream.
    return itertools.chain(prefetched, iterator), False
13 changes: 0 additions & 13 deletions nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import copy
import hashlib
import itertools
import json
import os
from typing import Any, Mapping, Optional
Expand Down Expand Up @@ -135,18 +134,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None, no_lm_init=False):
# register encoder config
self.register_bert_model()

def _prefetch(self, iterator):
    """Peek at ``iterator`` to find out whether it has any items left.

    Used in models using dataloader_iter to prefetch the next batch before the
    fwd_bwd func is called, to avoid PP rank 2 waiting indefinitely to get
    outputs from PP 1.

    Returns:
        A ``(iterator, done)`` tuple. If an item was available, the first
        entry is a new iterator that yields that item followed by the rest of
        ``iterator``, and ``done`` is ``False``. If ``iterator`` was already
        exhausted, it is returned unchanged with ``done`` set to ``True``.
    """
    # Unique marker so that exhaustion is never confused with a real item.
    exhausted = object()
    head = next(iterator, exhausted)
    if head is exhausted:
        return iterator, True

    # Put the peeked item back at the front of the stream.
    return itertools.chain((head,), iterator), False

def register_artifact(
self, config_path: str, src: str, verify_src_exists: bool = False,
Expand Down

0 comments on commit aa5b5fb

Please sign in to comment.