Manual garbage collection with an interval (#6469) (#6482)
* Manual garbage collection with an interval



* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use trainer.global_step for tracking the interval of GC



---------

Signed-off-by: Sangkug Lym <[email protected]>
Co-authored-by: Sangkug Lym <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
4 people authored and yaoyu-33 committed May 26, 2023
1 parent cc258bb commit 13d0895
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -216,3 +216,7 @@ model:
      warmup_steps: 500
      constant_steps: 50000
      min_lr: 2e-5

  gc_interval: 0
  # Interval (in batch steps) of host memory garbage collection. When it is zero, collection relies on the automatic garbage collector.
  # If an integer value larger than zero is set, collection is done manually every `gc_interval` batch steps.
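For illustration only (not part of this commit), enabling manual collection in the same config file could look like the line below; the interval value is an arbitrary example:

    gc_interval: 100  # disable automatic GC; call gc.collect() every 100 global steps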
11 changes: 11 additions & 0 deletions
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import re
from typing import Any, Dict, Optional, Union
@@ -148,6 +149,13 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
"default_on_epoch": False,
}

        self.gc_interval = cfg.get('gc_interval', 0)
        assert self.gc_interval >= 0, "gc_interval should be an integer value larger than or equal to 0."
        # If gc_interval > 0, memory garbage collection is manually controlled.
        # The automatic garbage collector should be disabled before training starts.
        if self.gc_interval > 0:
            gc.disable()

    def _enable_nvidia_optimizations(self):
        "These optimizations are present in NVIDIA NGC PyTorch Containers"

@@ -351,6 +359,9 @@ def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unus
            # accumulated gradient updates.
            grad_scaler.optimizer_update_skipped = None

        if self.gc_interval > 0 and (self.trainer.global_step % self.gc_interval == 0):
            gc.collect()

    def setup_optimization(
        self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None,
    ):
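Taken together, the two changes implement the following pattern. A minimal standalone sketch (plain Python, not NeMo code; the train_step stub, the step count, and the interval value are assumptions for illustration):

    import gc

    GC_INTERVAL = 100  # example value; plays the role of model.gc_interval

    def train_step(step: int) -> None:
        pass  # placeholder for the real forward/backward/optimizer work

    def train(num_steps: int) -> None:
        # With a positive interval, switch off the automatic collector before
        # training so GC pauses cannot occur at arbitrary points mid-step.
        gc.disable()
        for step in range(1, num_steps + 1):
            train_step(step)
            # Collect manually at a fixed step interval, mirroring the check
            # added to on_train_batch_end above.
            if step % GC_INTERVAL == 0:
                gc.collect()

    train(1000)

Keying the check to trainer.global_step rather than a local counter (the third bullet in the commit message) keeps the interval tied to the actual optimizer-step count, for example across checkpoint resumes.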
