
Commit b8e001a

cp: Add import guards for mcore lightning module (#14970) into r2.5.0 (#14982)
* Add import guards for mcore lightning module (#14970)
  * Add import guards for mcore lightning module
  * Apply isort and black reformatting
  * Add import guard for Apex
  * Add asr install checker
* Apply isort and black reformatting

Signed-off-by: Charlie Truong <[email protected]>
Signed-off-by: chtruong814 <[email protected]>
Co-authored-by: chtruong814 <[email protected]>
1 parent 409e409 commit b8e001a

File tree

11 files changed, +230 -41 lines changed

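Every file touched by this commit applies the same guard pattern: wrap the megatron.core (or Apex) import in try/except, record the outcome in a module-level HAVE_MEGATRON_CORE (or HAVE_APEX) flag, and rebind any name used in annotations or base-class lists to `object` so the module still parses and imports on a slim install. A minimal self-contained sketch of the pattern (the guard mirrors the diffs below; build_optimizer is a hypothetical caller, not code from this commit):

try:
    # Optional heavy dependency; absent in slim installs such as pip install "nemo_toolkit[asr]".
    from megatron.core.optimizer import OptimizerConfig

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    # Rebind the name so annotations like `cfg: OptimizerConfig` still resolve at import time.
    OptimizerConfig = object
    HAVE_MEGATRON_CORE = False


def build_optimizer(cfg: OptimizerConfig):
    # Fail at call time with a clear message instead of crashing at import time.
    if not HAVE_MEGATRON_CORE:
        raise ImportError("megatron-core is required for this feature; please install it.")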

.github/workflows/install-test.yml

Lines changed: 96 additions & 0 deletions
@@ -130,6 +130,54 @@ jobs:
           python tests/core_ptl/check_imports.py --domain "$collection"
         done

+  test-asr-install-linux-amd:
+    name: ubuntu-22.04-amd-py${{ matrix.python }}-asr
+    runs-on: ubuntu-22.04
+    strategy:
+      fail-fast: false
+      matrix:
+        python: ["3.10", "3.11", "3.12"]
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v2
+
+      - name: Check disk space before cleanup
+        run: df -h
+
+      - name: Free up disk space
+        run: |
+          # Remove unnecessary packages and files on Ubuntu
+          sudo apt-get clean
+          sudo rm -rf /usr/local/lib/android || true
+          sudo rm -rf /opt/ghc || true
+          sudo rm -rf /usr/local/.ghcup || true
+          sudo rm -rf /usr/share/dotnet || true
+          sudo rm -rf /opt/az || true
+          # Clear pip and npm caches
+          pip cache purge || true
+          sudo npm cache clean --force || true
+
+      - name: Check disk space after cleanup
+        run: df -h
+
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Install NeMo
+        run: |
+          pip install --no-cache-dir --upgrade pip
+          pip install --no-cache-dir ".[asr]"
+
+      - name: Check disk space after installation
+        run: df -h
+
+      - name: Run import checks
+        run: |
+          # Run import checks
+          python tests/core_ptl/check_imports.py --domain asr
+
   test-installs-linux-arm:
     name: ubuntu-22.04-arm-py${{ matrix.python }}-${{ matrix.installer }}
     runs-on: ubuntu-22.04-arm

@@ -188,3 +236,51 @@ jobs:
           for collection in "asr" "tts" "lightning" "core"; do
             python tests/core_ptl/check_imports.py --domain "$collection"
           done
+
+  test-asr-installs-linux-arm:
+    name: ubuntu-22.04-arm-py${{ matrix.python }}-asr
+    runs-on: ubuntu-22.04-arm
+    strategy:
+      fail-fast: false
+      matrix:
+        python: ["3.10", "3.11", "3.12"]
+    steps:
+      - name: Checkout repo
+        uses: actions/checkout@v2
+
+      - name: Check disk space before cleanup
+        run: df -h
+
+      - name: Free up disk space
+        run: |
+          # Remove unnecessary packages and files on Ubuntu ARM
+          sudo apt-get clean
+          sudo rm -rf /usr/local/lib/android || true
+          sudo rm -rf /opt/ghc || true
+          sudo rm -rf /usr/local/.ghcup || true
+          sudo rm -rf /usr/share/dotnet || true
+          sudo rm -rf /opt/az || true
+          # Clear pip and npm caches
+          pip cache purge || true
+          sudo npm cache clean --force || true
+
+      - name: Check disk space after cleanup
+        run: df -h
+
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Install NeMo
+        run: |
+          pip install --no-cache-dir --upgrade pip
+          pip install --no-cache-dir ".[asr]"
+
+      - name: Check disk space after installation
+        run: df -h
+
+      - name: Run import checks
+        run: |
+          # Run import checks
+          python tests/core_ptl/check_imports.py --domain asr
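The two new jobs above exercise a slim `.[asr]`-only install on AMD and ARM runners, then run the repository's import checker against just that domain. The checker's source is not part of this diff; conceptually it reduces to something like the following sketch (a hypothetical simplification, not the actual tests/core_ptl/check_imports.py):

import argparse
import importlib


def check_domain(domain: str) -> None:
    # Importing the collection package executes every module-level import,
    # which is exactly where an unguarded optional dependency would crash.
    module = importlib.import_module(f"nemo.collections.{domain}")
    print(f"OK: imported {module.__name__}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--domain", required=True)
    check_domain(parser.parse_args().domain)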

nemo/lightning/fabric/strategies.py

Lines changed: 13 additions & 2 deletions
@@ -42,8 +42,19 @@
 from lightning.pytorch.loops.fetchers import _DataFetcher
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
-from megatron.core.distributed import DistributedDataParallelConfig
-from megatron.core.optimizer import OptimizerConfig
+
+try:
+    from megatron.core.distributed import DistributedDataParallelConfig
+    from megatron.core.optimizer import OptimizerConfig
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    DistributedDataParallelConfig = object
+    OptimizerConfig = object
+    HAVE_MEGATRON_CORE = False
+
 from torch import Tensor, nn
 from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
 from torch.nn import Module

nemo/lightning/io/pl.py

Lines changed: 20 additions & 11 deletions
@@ -22,17 +22,26 @@
 from lightning.fabric.plugins import CheckpointIO
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.types import _PATH
-from megatron.core.dist_checkpointing.serialization import (
-    get_default_load_sharded_strategy,
-    get_default_save_sharded_strategy,
-)
-from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
-from megatron.core.dist_checkpointing.strategies.fully_parallel import (
-    FullyParallelLoadStrategyWrapper,
-    FullyParallelSaveStrategyWrapper,
-)
-from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
-from megatron.core.parallel_state import get_data_parallel_group
+
+try:
+    from megatron.core.dist_checkpointing.serialization import (
+        get_default_load_sharded_strategy,
+        get_default_save_sharded_strategy,
+    )
+    from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
+    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
+        FullyParallelLoadStrategyWrapper,
+        FullyParallelSaveStrategyWrapper,
+    )
+    from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
+    from megatron.core.parallel_state import get_data_parallel_group
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    HAVE_MEGATRON_CORE = False
+
 from torch import nn
 from typing_extensions import Self, override

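Note that unlike most of the other guards in this commit, the one above binds no placeholder names: the dist-checkpointing symbols are only referenced inside function bodies, so the module imports cleanly and callers are expected to consult the flag first. A hypothetical runtime guard of that shape (illustrative only, not part of this commit):

def _require_megatron_core(feature: str) -> None:
    # Call at the top of any code path that touches the guarded imports above.
    if not HAVE_MEGATRON_CORE:
        raise ImportError(f"{feature} requires megatron-core; install it with: pip install megatron-core")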
nemo/lightning/megatron_init.py

Lines changed: 11 additions & 9 deletions
@@ -60,15 +60,17 @@

 except (ImportError, ModuleNotFoundError):
     logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
-    from apex.transformer.microbatches import ConstantNumMicroBatches as ConstantNumMicroBatchesCalculator
-    from apex.transformer.pipeline_parallel.utils import (
-        get_current_global_batch_size,
-        get_micro_batch_size,
-        get_num_microbatches,
-    )
-    from apex.transformer.pipeline_parallel.utils import (
-        setup_microbatch_calculator as init_num_microbatches_calculator,
-    )
+
+    if HAVE_APEX:
+        from apex.transformer.microbatches import ConstantNumMicroBatches as ConstantNumMicroBatchesCalculator
+        from apex.transformer.pipeline_parallel.utils import (
+            get_current_global_batch_size,
+            get_micro_batch_size,
+            get_num_microbatches,
+        )
+        from apex.transformer.pipeline_parallel.utils import (
+            setup_microbatch_calculator as init_num_microbatches_calculator,
+        )

     MCORE_MB_CALCULATOR = False
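The earlier code assumed Apex was importable whenever megatron.core's calculator was missing; nesting the fallback under HAVE_APEX (assumed to be set by a guarded `import apex` earlier in this module) lets the module import with neither package present. The effective resolution order, condensed into a sketch:

try:
    # 1. Prefer the megatron.core implementation.
    from megatron.core.num_microbatches_calculator import get_num_microbatches
except (ImportError, ModuleNotFoundError):
    # 2. Fall back to Apex, but only if Apex itself imported.
    if HAVE_APEX:
        from apex.transformer.pipeline_parallel.utils import get_num_microbatches
    # 3. With neither installed, get_num_microbatches stays undefined and
    #    callers must check the availability flags before using it.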
nemo/lightning/megatron_parallel.py

Lines changed: 17 additions & 5 deletions
@@ -50,11 +50,23 @@
 import torch.distributed
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities import move_data_to_device
-from megatron.core import parallel_state
-from megatron.core.distributed import DistributedDataParallel as McoreDDP
-from megatron.core.distributed import DistributedDataParallelConfig
-from megatron.core.optimizer import OptimizerConfig
-from megatron.core.transformer.transformer_config import TransformerConfig
+
+
+try:
+    from megatron.core import parallel_state
+    from megatron.core.distributed import DistributedDataParallel as McoreDDP
+    from megatron.core.distributed import DistributedDataParallelConfig
+    from megatron.core.optimizer import OptimizerConfig
+    from megatron.core.transformer.transformer_config import TransformerConfig
+
+    HAVE_MEGATRON_CORE = True
+except (ImportError, ModuleNotFoundError):
+
+    McoreDDP = object
+    DistributedDataParallelConfig = object
+    TransformerConfig = object
+    HAVE_MEGATRON_CORE = False
+
 from torch import Tensor, nn
 from typing_extensions import override

nemo/lightning/pytorch/callbacks/ddp_parity_checker.py

Lines changed: 9 additions & 1 deletion
@@ -16,7 +16,15 @@

 import torch
 from lightning.pytorch.callbacks.callback import Callback
-from megatron.core.utils import check_param_hashes_across_dp_replicas
+
+try:
+    from megatron.core.utils import check_param_hashes_across_dp_replicas
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    HAVE_MEGATRON_CORE = False

 from nemo.lightning import io
 from nemo.utils import logging
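With only the flag bound, a callback like this can surface the missing dependency at setup time rather than breaking the package import. A hedged sketch of that shape (MegatronOnlyCallback is a hypothetical name; the actual DdpParityChecker changes are not shown in this hunk):

from lightning.pytorch.callbacks.callback import Callback


class MegatronOnlyCallback(Callback):
    def setup(self, trainer, pl_module, stage):
        # HAVE_MEGATRON_CORE comes from the guarded import above.
        if not HAVE_MEGATRON_CORE:
            raise ImportError("This callback needs megatron-core to hash parameters across DP replicas.")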

nemo/lightning/pytorch/callbacks/progress_printer.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,16 @@

 from lightning.pytorch.callbacks.progress import ProgressBar
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from megatron.core.num_microbatches_calculator import get_num_microbatches
+
+try:
+    from megatron.core.num_microbatches_calculator import get_num_microbatches
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    HAVE_MEGATRON_CORE = False
+
 from typing_extensions import override

nemo/lightning/pytorch/optim/megatron.py

Lines changed: 13 additions & 3 deletions
@@ -15,9 +15,19 @@
 from typing import Callable, List, Optional

 import lightning.pytorch as pl
-from megatron.core.distributed import finalize_model_grads
-from megatron.core.optimizer import OptimizerConfig
-from megatron.core.utils import get_model_config
+
+try:
+    from megatron.core.distributed import finalize_model_grads
+    from megatron.core.optimizer import OptimizerConfig
+    from megatron.core.utils import get_model_config
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    OptimizerConfig = object
+    HAVE_MEGATRON_CORE = False
+
 from torch.optim import Optimizer

 from nemo.lightning._strategy_lib import setup_megatron_optimizer

nemo/lightning/pytorch/strategies/fsdp_strategy.py

Lines changed: 11 additions & 1 deletion
@@ -29,7 +29,17 @@
 from lightning.pytorch.strategies.fsdp import FSDPStrategy as PLFSDPStrategy
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from megatron.core.transformer.transformer_layer import TransformerLayer
+
+try:
+    from megatron.core.transformer.transformer_layer import TransformerLayer
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    TransformerLayer = object
+    HAVE_MEGATRON_CORE = False
+
 from torch.distributed.checkpoint.state_dict import (  # get_state_dict,
     StateDictOptions,
     get_optimizer_state_dict,

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 13 additions & 4 deletions
@@ -52,10 +52,19 @@
 from lightning.pytorch.strategies.ddp import DDPStrategy
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from megatron.core import Timers
-from megatron.core.dist_checkpointing.validation import StrictHandling
-from megatron.core.distributed import DistributedDataParallelConfig
-from megatron.core.optimizer import OptimizerConfig
+
+try:
+    from megatron.core import Timers
+    from megatron.core.dist_checkpointing.validation import StrictHandling
+    from megatron.core.distributed import DistributedDataParallelConfig
+    from megatron.core.optimizer import OptimizerConfig
+
+    HAVE_MEGATRON_CORE = True
+except (ImportError, ModuleNotFoundError):
+
+    DistributedDataParallelConfig = object
+    HAVE_MEGATRON_CORE = False
+
 from torch import nn
 from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
 from torch.distributed.checkpoint.utils import CheckpointException
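A quick way to verify the net effect of all these guards is to import the touched modules from an environment without megatron-core and read back their flags (a hypothetical smoke test, not part of this commit):

import importlib

# In a slim install, each guarded module should import cleanly and report
# the missing dependency through its module-level flag.
for name in (
    "nemo.lightning.pytorch.callbacks.progress_printer",
    "nemo.lightning.pytorch.strategies.megatron_strategy",
):
    module = importlib.import_module(name)
    print(name, "HAVE_MEGATRON_CORE =", getattr(module, "HAVE_MEGATRON_CORE", None))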
