
Commit 6060d9c

titu1994 authored and ericharper committed
Fix documentation for Numba (NVIDIA#7065)
* Fix documentation for Numba
  Signed-off-by: smajumdar <[email protected]>

* Update force float32 flag dynamically
  Signed-off-by: smajumdar <[email protected]>

* Update force float32 flag dynamically
  Signed-off-by: smajumdar <[email protected]>

* Fix nemo version
  Signed-off-by: smajumdar <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: zhehuaichen <[email protected]>
1 parent cb74f7c commit 6060d9c

File tree

6 files changed: +18 −13 lines


README.rst (+2 −2)

@@ -132,8 +132,8 @@ Built for speed, NeMo can utilize NVIDIA's Tensor Cores and scale out training t
 Requirements
 ------------
 
-1) Python 3.8 or above
-2) Pytorch 1.10.0 or above
+1) Python 3.9 or above
+2) Pytorch 1.13.1 or above
 3) NVIDIA GPU for training
 
 Documentation

docs/source/nlp/api.rst (+1 −1)

@@ -124,7 +124,7 @@ Datasets
 .. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset
     :show-inheritance:
 
-.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTChatDataset
+.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset.GPTSFTChatDataset
     :show-inheritance:
 
 .. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.retro_dataset.RETRODataset

docs/source/starthere/intro.rst (+2 −2)

@@ -34,9 +34,9 @@ Prerequisites
 
 Before you begin using NeMo, it's assumed you meet the following prerequisites.
 
-#. You have Python version 3.6, 3.7 or 3.8.
+#. You have Python version 3.9 or 3.10.
 
-#. You have Pytorch version 1.8.1.
+#. You have Pytorch version 1.13.1 or 2.0+.
 
 #. You have access to an NVIDIA GPU for training.
 

nemo/collections/asr/losses/rnnt.py (+5 −2)

@@ -99,7 +99,7 @@ class RNNTLossConfig:
         min_version='0.53.0',
         is_available=NUMBA_RNNT_AVAILABLE,
         installation_msg=NUMBA_INSTALLATION_MESSAGE,
-        force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
+        force_float32=False,  # This is only temporarily False; it is updated dynamically during loss resolution
     ),
     "pytorch": RNNTLossConfig(
         loss_name="pytorch",
@@ -258,6 +258,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
         _warn_unused_additional_kwargs(loss_name, loss_kwargs)
 
     elif loss_name == 'warprnnt_numba':
+        # Override the config's force_float32 placeholder with the runtime FP16 support check
+        loss_config.force_float32 = not numba_utils.is_numba_cuda_fp16_supported()
+
         fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
         clamp = loss_kwargs.pop('clamp', -1.0)
         loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp)
@@ -444,7 +447,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
         max_targets_len = target_lengths.max()
 
         # Force cast joint to float32
-        if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
+        if not self._force_float32 and numba_utils.is_numba_cuda_fp16_supported():
             # Execute the kernel in fp16
             pass
         elif self._force_float32 and log_probs.dtype != torch.float32:
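The substance of this change: `force_float32` is no longer fixed at import time from the module-level `NUMBA_FP16_SUPPORTED` constant; it starts as a placeholder and is resolved only when the loss is actually requested, so the flag reflects the environment (e.g. `NUMBA_CUDA_USE_NVIDIA_BINDING`) at resolution time. A minimal sketch of that deferred-resolution pattern, with hypothetical names (`LossConfig`, `fp16_supported`, and `resolve_loss` are illustrative, not NeMo's API):

from dataclasses import dataclass


@dataclass
class LossConfig:
    loss_name: str
    force_float32: bool = False  # placeholder; resolved when the loss is built


def fp16_supported() -> bool:
    # Hypothetical stand-in for numba_utils.is_numba_cuda_fp16_supported();
    # a real check would inspect the installed numba version and CUDA bindings.
    return False


def resolve_loss(name: str, registry: dict) -> LossConfig:
    config = registry[name]
    if name == 'warprnnt_numba':
        # Probe support now, not at import, so the answer matches the
        # environment at the moment the loss is requested.
        config.force_float32 = not fp16_supported()
    return config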

nemo/core/utils/numba_utils.py (+6 −5)

@@ -29,9 +29,6 @@
 __NUMBA_MINIMUM_VERSION__ = os.environ.get("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)
 
 __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__ = "0.57.0"
-NUMBA_FP16_SUPPORTED = model_utils.check_lib_version(
-    'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
-)[0]
 
 
 NUMBA_INSTALLATION_MESSAGE = (
@@ -171,12 +168,16 @@ def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tup
         use_nvidia_binding = False
         reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is not available or has not set to `1`."
 
-    if NUMBA_FP16_SUPPORTED:
+    numba_fp16_version_correct = model_utils.check_lib_version(
+        'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
+    )[0]
+
+    if numba_fp16_version_correct:
         reason += f"Numba CUDA FP16 is supported in installed numba version."
     else:
         reason += f"Numba CUDA FP16 is not supported in installed numba version."
 
-    result = use_nvidia_binding and NUMBA_FP16_SUPPORTED
+    result = use_nvidia_binding and numba_fp16_version_correct
 
     if return_reason:
         return result, reason
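The same motivation applies here: the version comparison formerly baked into the module-level `NUMBA_FP16_SUPPORTED` constant now runs inside `is_numba_cuda_fp16_supported()` each time it is called. A rough sketch of what such a call-time capability check amounts to, assuming only the `packaging` library (the function names here are illustrative, not NeMo's):

import operator
import os

from packaging import version


def numba_fp16_version_ok(min_fp16_version: str = '0.57.0') -> bool:
    # Illustrative stand-in for model_utils.check_lib_version(...)[0]:
    # compare the installed numba version against the FP16 minimum.
    try:
        import numba
    except ImportError:
        return False
    return operator.ge(version.Version(numba.__version__), version.Version(min_fp16_version))


def nvidia_binding_enabled() -> bool:
    # Mirrors the env-var condition checked earlier in the real function.
    return os.environ.get('NUMBA_CUDA_USE_NVIDIA_BINDING') == '1'


# Both conditions must hold; evaluated at call time, not import time.
fp16_ok = numba_fp16_version_ok() and nvidia_binding_enabled()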

nemo/utils/model_utils.py (+2 −1)

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import copy
+import importlib
 import os
 from dataclasses import dataclass, is_dataclass
 from enum import Enum
@@ -554,7 +555,7 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op
         if '.' in lib_name:
             mod = import_class_by_path(lib_name)
         else:
-            mod = __import__(lib_name)
+            mod = importlib.import_module(lib_name)
 
         if hasattr(mod, '__version__'):
             lib_ver = version.Version(mod.__version__)
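On the one-line swap: the Python documentation recommends `importlib.import_module` over calling `__import__` directly. For a plain module name the two are equivalent, but for a dotted path `__import__` returns the top-level package while `import_module` returns the submodule itself, which is what the subsequent `__version__` lookup needs:

import importlib

top = __import__('os.path')               # binds the top-level 'os' package
sub = importlib.import_module('os.path')  # binds the 'os.path' module itself

print(top.__name__)  # 'os'
print(sub.__name__)  # 'posixpath' (or 'ntpath' on Windows)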
