diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
index 4a438e127e..fe953390e8 100755
--- a/examples/configs/dpo.yaml
+++ b/examples/configs/dpo.yaml
@@ -54,6 +54,7 @@ policy:
     tensor_parallel_size: 1
     context_parallel_size: 1
     custom_parallel_plan: null
+    clear_cache_every_n_steps: null

   dynamic_batching:
     enabled: false
diff --git a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml
index 084ea843f2..e20d7970bb 100644
--- a/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml
+++ b/examples/configs/recipes/llm/dpo-mistral-nemo-instruct-2407-1n8g-fsdp2tp8-actckpt-long.yaml
@@ -49,6 +49,7 @@ policy:
     tensor_parallel_size: 8
     context_parallel_size: 1
     custom_parallel_plan: null
+    clear_cache_every_n_steps: 1

   env_vars:
     PYTORCH_CUDA_ALLOC_CONF: "max_split_size_mb:64"
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index 3f2fcfe877..7e38938db9 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -27,6 +27,7 @@ class DTensorConfig(TypedDict):
     tensor_parallel_size: NotRequired[int]
     context_parallel_size: NotRequired[int]
     custom_parallel_plan: NotRequired[str]
+    clear_cache_every_n_steps: NotRequired[int]


 class SequencePackingConfig(TypedDict):
diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py
index 7b2f0de271..14d6a118e6 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker.py
@@ -16,6 +16,7 @@
 import gc
 import itertools
 import os
+import warnings
 from collections import defaultdict
 from contextlib import AbstractContextManager, contextmanager, nullcontext
 from typing import Any, Generator, Iterable, Optional, Set, Union, cast
@@ -629,10 +630,20 @@ def train(
                 mb_iterator = batch.make_microbatch_iterator(mbs)
                 iterator_len = batch.size // mbs

+            empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get(
+                "clear_cache_every_n_steps"
+            )
+            if empty_cache_steps:
+                warnings.warn(
+                    f"Emptying the CUDA cache every {empty_cache_steps} microbatches; doing so unnecessarily can incur a large performance overhead."
+                )
+
             for mb_idx, mb in enumerate(
                 itertools.chain(mb_iterator, dummy_iterator)
             ):
-                torch.cuda.empty_cache()
+                # Conditionally empty the cache when sensitive to memory fragmentation
+                if empty_cache_steps and mb_idx % empty_cache_steps == 0:
+                    torch.cuda.empty_cache()

                 with torch.autocast(device_type="cuda", dtype=self.dtype):
                     if self.enable_seq_packing:
diff --git a/nemo_rl/models/policy/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
index 469db0cd3e..47393df390 100644
--- a/nemo_rl/models/policy/dtensor_policy_worker_v2.py
+++ b/nemo_rl/models/policy/dtensor_policy_worker_v2.py
@@ -15,6 +15,7 @@
 import gc
 import itertools
 import os
+import warnings
 from collections import defaultdict
 from contextlib import AbstractContextManager, contextmanager, nullcontext
 from typing import Any, Generator, Iterable, Optional, cast
@@ -573,10 +574,20 @@ def train(
                 mb_iterator = batch.make_microbatch_iterator(mbs)
                 iterator_len = batch.size // mbs

+            empty_cache_steps = self.cfg.get("dtensor_cfg", {}).get(
+                "clear_cache_every_n_steps"
+            )
+            if empty_cache_steps:
+                warnings.warn(
+                    f"Emptying the CUDA cache every {empty_cache_steps} microbatches; doing so unnecessarily can incur a large performance overhead."
+                )
+
             for mb_idx, mb in enumerate(
                 itertools.chain(mb_iterator, dummy_iterator)
             ):
-                torch.cuda.empty_cache()
+                # Conditionally empty the cache when sensitive to memory fragmentation
+                if empty_cache_steps and mb_idx % empty_cache_steps == 0:
+                    torch.cuda.empty_cache()

                 with torch.autocast(device_type="cuda", dtype=self.dtype):
                     if self.enable_seq_packing:
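A minimal usage sketch of the new option in a user config (the nesting under policy.dtensor_cfg is assumed from the self.cfg.get("dtensor_cfg", {}) lookup above; surrounding keys are illustrative only):

    policy:
      dtensor_cfg:
        tensor_parallel_size: 1
        clear_cache_every_n_steps: 2  # empties the CUDA cache on microbatches 0, 2, 4, ...; null or unset disables cache clearing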