diff --git a/nemo/collections/asr/losses/rnnt.py b/nemo/collections/asr/losses/rnnt.py
index 10b85acb42ef..a884f7d3cc68 100644
--- a/nemo/collections/asr/losses/rnnt.py
+++ b/nemo/collections/asr/losses/rnnt.py
@@ -38,9 +38,10 @@
 from nemo.collections.asr.losses.rnnt_pytorch import MultiblankRNNTLossPytorch, RNNTLossPytorch, TDTLossPytorch
 from nemo.core.classes import Loss, typecheck
 from nemo.core.neural_types import LabelsType, LengthsType, LogprobsType, LossType, NeuralType
+from nemo.core.utils import numba_utils
 from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE
 from nemo.core.utils.numba_utils import NUMBA_INSTALLATION_MESSAGE
-from nemo.utils import logging, model_utils
+from nemo.utils import logging, logging_mode, model_utils
 
 try:
     import warprnnt_pytorch as warprnnt
@@ -98,7 +99,7 @@ class RNNTLossConfig:
         min_version='0.53.0',
         is_available=NUMBA_RNNT_AVAILABLE,
         installation_msg=NUMBA_INSTALLATION_MESSAGE,
-        force_float32=True,
+        force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
     ),
     "pytorch": RNNTLossConfig(
         loss_name="pytorch",
@@ -387,7 +388,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
             for the standard "blank" symbol. In particular, say V is the number of non-blank tokens in
             the vocabulary, then in the case of,
             standard RNNT: num_classes = V
-            multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before 
+            multiblank RNNT: num_classes = V + number-big-blanks (since we store big-blanks before
                 standard blank, and the standard blank is the last symbol in the vocab)
             TDT: num_classes = V. Note, V here does not include any of the "duration outputs".
 
@@ -413,6 +414,7 @@ def __init__(self, num_classes, reduction: str = 'mean_batch', loss_name: str =
         self.reduction = reduction
         self._loss = resolve_rnnt_loss(loss_name, blank_idx=self._blank, loss_kwargs=loss_kwargs)
         self._force_float32 = RNNT_LOSS_RESOLVER[loss_name].force_float32
+        self._fp16_compat_checked = False
 
     def reduce(self, losses, target_lengths):
 
@@ -442,8 +444,22 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
         max_targets_len = target_lengths.max()
 
         # Force cast joint to float32
-        # TODO: Remove once Numba supports FP16
-        if self._force_float32 and log_probs.dtype != torch.float32:
+        if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
+            # Execute the kernel in fp16
+            pass
+        elif self._force_float32 and log_probs.dtype != torch.float32:
+            # Log just once if fp16 tensor was passed and fp16 Numba CUDA loss could not be used.
+            if log_probs.dtype == torch.float16 and not self._fp16_compat_checked:
+                _, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
+                logging.warning(
+                    f"Provided RNNT Joint tensor is of dtype {log_probs.dtype}, but RNNT loss could not be calculated "
+                    f"in fp16 due to following reason stated below. Loss will be calculated in fp32. \n\n"
+                    f"{reason}",
+                    mode=logging_mode.ONCE,
+                )
+                self._fp16_compat_checked = True
+
+            # Upcast the activation tensor and compute loss and grads in fp32
             logits_orig = log_probs
             log_probs = log_probs.float()
             del logits_orig  # save memory *before* computing the loss
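Note: the new branch above only changes which dtype reaches the Numba kernel; the public `RNNTLoss` API is unchanged. A minimal sketch of how the fp16 path is expected to be exercised (the shapes, vocabulary size, and `loss_name` below are illustrative assumptions, not part of this diff):

```python
# Illustrative usage of the fp16-aware Numba RNNT loss; shapes and loss_name are assumptions.
import torch
from nemo.collections.asr.losses.rnnt import RNNTLoss

B, T, U, V = 2, 50, 10, 28  # batch, acoustic frames, target length + 1, non-blank vocab size
loss_fn = RNNTLoss(num_classes=V, loss_name="warprnnt_numba")

joint = torch.randn(B, T, U, V + 1, device="cuda", dtype=torch.float16, requires_grad=True)
targets = torch.randint(low=1, high=V, size=(B, U - 1), device="cuda", dtype=torch.int64)
input_lengths = torch.full((B,), T, device="cuda", dtype=torch.int64)
target_lengths = torch.full((B,), U - 1, device="cuda", dtype=torch.int64)

# With numba >= 0.57.0 and NUMBA_CUDA_USE_NVIDIA_BINDING=1 the kernel now runs directly on the
# fp16 joint; otherwise the tensor is upcast to fp32 and the warning above is emitted once.
loss = loss_fn(log_probs=joint, targets=targets, input_lengths=input_lengths, target_lengths=target_lengths)
loss.backward()
```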
diff --git a/nemo/collections/asr/losses/rnnt_pytorch.py b/nemo/collections/asr/losses/rnnt_pytorch.py
index bc6e5a25a3b2..c8eee90a2eb5 100644
--- a/nemo/collections/asr/losses/rnnt_pytorch.py
+++ b/nemo/collections/asr/losses/rnnt_pytorch.py
@@ -47,7 +47,12 @@ def __init__(self, blank, reduction):
         self.reduction = reduction
 
     def forward(self, acts, labels, act_lens, label_lens):
+        # CPU patch for FP16
+        if not acts.is_cuda and acts.dtype == torch.float16:
+            acts = acts.float()
+
         acts = torch.log_softmax(acts, -1)
+
         forward_logprob = self.compute_forward_prob(acts, labels, act_lens, label_lens)
         losses = -forward_logprob
         if self.reduction == 'mean_batch':
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
index 118ee88acbfe..046aea425e20 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt.py
@@ -186,7 +186,7 @@ def rnnt_loss_gpu(
     # Select GPU index
     cuda.select_device(acts.device.index)
 
-    gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=acts.dtype, requires_grad=False)
+    gpu_workspace = torch.zeros(gpu_size, device=acts.device, dtype=torch.float32, requires_grad=False)
 
     ### VIEW TENSORS AS VECTORS FOR POINTER INDEXING ###
     acts, acts_shape = rnnt_helper.flatten_tensor(acts)
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
index eaa6d332a0fc..58508970aa83 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_numpy.py
@@ -344,10 +344,15 @@ def forward(self, acts, labels, act_lens, label_lens):
         _assert_no_grad(label_lens)
         certify_inputs(acts, labels, act_lens, label_lens)
 
+        # CPU Patch for fp16 - force cast to fp32
+        if not acts.is_cuda and acts.dtype == torch.float16:
+            acts = acts.float()
+
         if self.clamp > 0.0:
             acts = LogSoftmaxGradModification.apply(acts, self.clamp)
 
         acts = torch.nn.functional.log_softmax(acts, -1)
+
         return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)
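The CPU-side patches above, and the ones that follow for `rnnt_pytorch.py` and `cpu_rnnt.py`, all apply the same rule: half-precision activations on CPU are upcast to fp32 before `log_softmax`. A condensed sketch of that rule in isolation (the helper name is illustrative):

```python
# Condensed form of the CPU fp16 handling added in this PR (helper name is illustrative).
import torch

def maybe_upcast_cpu_fp16(acts: torch.Tensor) -> torch.Tensor:
    # CUDA tensors and fp32/fp64 CPU tensors pass through untouched;
    # only CPU half tensors are promoted so that log_softmax runs in fp32.
    if not acts.is_cuda and acts.dtype == torch.float16:
        acts = acts.float()
    return acts

acts = torch.randn(1, 4, 3, 5, dtype=torch.float16)  # CPU half-precision joint (made-up shape)
log_probs = torch.log_softmax(maybe_upcast_cpu_fp16(acts), dim=-1)
```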
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
index 2ffe08be361e..5960d5ab6b18 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/rnnt_pytorch.py
@@ -57,7 +57,7 @@ def forward(ctx, acts, labels, act_lens, label_lens, blank, reduction, fastemit_
         loss_func = rnnt.rnnt_loss_gpu if is_cuda else rnnt.rnnt_loss_cpu
         grads = torch.zeros_like(acts) if acts.requires_grad else None
         minibatch_size = acts.size(0)
-        costs = torch.zeros(minibatch_size, device=acts.device, dtype=acts.dtype)
+        costs = torch.zeros(minibatch_size, device=acts.device, dtype=torch.float32)
 
         loss_func(
             acts,
@@ -119,7 +119,6 @@ def forward(
             label_lens: Tensor of (batch) containing label length of each example
             fastemit_lambda: Float scaling factor for FastEmit regularization. Refer to
                 FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
-            durations: list of durations for TDT model, must include 0 and 1, e.g. [0, 1, 2, 3, 4].
             sigma: hyper-parameter for logit under-normalization method for training
@@ -417,6 +416,10 @@ def forward(self, acts, labels, act_lens, label_lens):
             label_lens: Tensor of (batch) containing label length of each example
         """
         if not acts.is_cuda:
+            # Force FP32 until log_softmax() is implemented for fp16 on CPU
+            if acts.dtype == torch.float16:
+                acts = acts.float()
+
             # Since CPU requires log_softmax to be computed explicitly, we need to perform grad clipping
             # *after* we have obtained the gradients of loss(logsoftmax()).
             # This is highly wasteful since it requires a copy of the entire joint tensor which is expensive.
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
index 1528606716e1..3feb7b513a50 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
@@ -231,8 +231,8 @@ def cost_and_grad_kernel(
         )
 
         # Scale llForward by FastEmit lambda
-        llForward *= 1.0 + self.fastemit_lambda_
-        llBackward *= 1.0 + self.fastemit_lambda_
+        llForward += llForward * self.fastemit_lambda_
+        llBackward += llBackward * self.fastemit_lambda_
 
         diff = (llForward - llBackward).abs()
         if diff > 0.1:
@@ -300,6 +300,10 @@ def compute_betas_and_grads(
         Returns:
             Loglikelihood of the forward variable and inplace updates the grad tensor.
         """
+        # Patch for CPU + fp16
+        if log_probs.dtype == torch.float16 and not log_probs.is_cuda:
+            log_probs = log_probs.float()
+
         idx = CpuRNNT_index(U, self.maxU_, self.minibatch_, self.alphabet_size_, self.batch_first)
         betas[idx(T - 1, U - 1)] = log_probs[idx(T - 1, U - 1) * 2]
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/rnnt_helper.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/rnnt_helper.py
index b579b7315ef2..6ca7cd237264 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/rnnt_helper.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/rnnt_helper.py
@@ -30,6 +30,7 @@
 import math
 from typing import Optional, Tuple
 
+import numba
 import torch
 from numba import cuda
 
@@ -112,7 +113,7 @@ def compute_costs_data(source: torch.Tensor, dest: torch.Tensor, fastemit_lambda
     if idx < length:
         copy_data_1d(source, dest, idx)
         dest[idx] *= -1.0
-        dest[idx] *= 1.0 + fastemit_lambda
+        dest[idx] *= numba.float32(1.0 + fastemit_lambda)
 
 
 def get_workspace_size(
diff --git a/nemo/core/utils/numba_utils.py b/nemo/core/utils/numba_utils.py
index 6e1a8cb247d6..04010a2f7db4 100644
--- a/nemo/core/utils/numba_utils.py
+++ b/nemo/core/utils/numba_utils.py
@@ -17,6 +17,8 @@
 import operator
 import os
 
+from typing import Tuple, Union
+
 from nemo.utils import model_utils
 
 # Prevent Numba CUDA logs from showing at info level
@@ -26,6 +28,11 @@
 __NUMBA_DEFAULT_MINIMUM_VERSION__ = "0.53.0"
 __NUMBA_MINIMUM_VERSION__ = os.environ.get("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)
 
+__NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__ = "0.57.0"
+NUMBA_FP16_SUPPORTED = model_utils.check_lib_version(
+    'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
+)[0]
+
 
 NUMBA_INSTALLATION_MESSAGE = (
     "Could not import `numba`.\n"
@@ -148,6 +155,35 @@ def numba_cuda_is_supported(min_version: str) -> bool:
         return False
 
 
+def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]:
+    """
+    Utility method that returns a bool, stating if FP16 is supported for numba cuda kernels or not.
+
+    Returns:
+        bool, whether Numba CUDA will support fp16 or not; if `return_reason` is True, also returns
+        a str explaining the decision.
+    """
+    reason = ""
+    use_nvidia_binding = os.environ.get('NUMBA_CUDA_USE_NVIDIA_BINDING', None)
+    if use_nvidia_binding is not None:
+        use_nvidia_binding = use_nvidia_binding.lower() == "1"
+        reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is available and set to `1`. "
+    else:
+        use_nvidia_binding = False
+        reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is not available or is not set to `1`. "
+
+    if NUMBA_FP16_SUPPORTED:
+        reason += f"Numba CUDA FP16 is supported in installed numba version."
+    else:
+        reason += f"Numba CUDA FP16 is not supported in installed numba version."
+
+    result = use_nvidia_binding and NUMBA_FP16_SUPPORTED
+
+    if return_reason:
+        return result, reason
+    else:
+        return result
+
+
 def skip_numba_cuda_test_if_unsupported(min_version: str):
     """
     Helper method to skip pytest test case if numba cuda is not supported.
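A short sketch of how callers are expected to consume the two new helpers, mirroring their use in `losses/rnnt.py` above and in the tests that follow (the print statements are illustrative):

```python
# Sketch of how the new numba_utils helpers are meant to be consumed.
from nemo.core.utils import numba_utils

# Static check: installed numba is new enough (>= 0.57.0) to compile fp16 CUDA kernels.
if numba_utils.NUMBA_FP16_SUPPORTED:
    print("numba can compile fp16 CUDA kernels")

# Full runtime check: additionally requires NUMBA_CUDA_USE_NVIDIA_BINDING=1 in the environment.
supported, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
if not supported:
    print(f"RNNT loss will fall back to fp32: {reason}")
```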
diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py
index 3fbfcf6df54b..1a29a14f540d 100644
--- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py
+++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py
@@ -34,9 +34,14 @@
     DEVICES.append('cuda')
 
 
+DTYPES = [np.float32]
+if numba_utils.is_numba_cuda_fp16_supported():
+    DTYPES.append(np.float16)
+
+
 def wrap_and_call(fn, acts, labels, device):
     if not torch.is_tensor(acts):
-        acts = torch.FloatTensor(acts)
+        acts = torch.tensor(acts)
 
     if 'cuda' in device:
         acts = acts.cuda()
@@ -72,7 +77,8 @@ def wrap_and_call(fn, acts, labels, device):
 class TestRNNTLossPytorch:
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
-    def test_case_small(self, device):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_case_small(self, device, dtype):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
@@ -83,9 +89,13 @@ def test_case_small(self, device):
                 [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
             ]
         ]
-        )
+        ).astype(dtype)
         labels = [[1, 2]]
 
+        cost_threshold = 1e-8 if dtype == np.float32 else 5e-4
+        grad_threshold = 1e-8 if dtype == np.float32 else 1e-4
+        rtol = 1e-5 if dtype == np.float32 else 1e-3
+
         fn_pt = RNNTLossNumba(blank=0, reduction='sum')
         pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)
 
@@ -113,23 +123,28 @@ def test_case_small(self, device):
             ]
         )
 
-        assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(pt_grads, expected_grads), "small_test gradient mismatch."
+        assert np.allclose(pt_cost, expected_cost, atol=cost_threshold, rtol=1e-6), "small_test costs mismatch."
+        assert np.allclose(pt_grads, expected_grads, atol=grad_threshold, rtol=rtol), "small_test gradient mismatch."
 
-        assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch."
+        assert np.allclose(pt_cost, np_cost, atol=cost_threshold, rtol=rtol), "small_test costs mismatch."
+        assert np.allclose(pt_grads, np_grads, atol=grad_threshold, rtol=rtol), "small_test gradient mismatch."
 
-        assert np.allclose(ag_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(ag_grads, np_grads), "small_test gradient mismatch."
+        assert np.allclose(ag_cost, np_cost, atol=cost_threshold, rtol=rtol), "small_test costs mismatch."
+        assert np.allclose(ag_grads, np_grads, atol=cost_threshold, rtol=rtol), "small_test gradient mismatch."
 
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
-    def test_case_small_random(self, device):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_case_small_random(self, device, dtype):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
+        cost_threshold = 1e-8 if dtype == np.float32 else 5e-4
+        grad_threshold = 1e-8 if dtype == np.float32 else 1e-4
+        rtol = 1e-5 if dtype == np.float32 else 1e-3
+
         rng = np.random.RandomState(0)
-        acts = rng.randn(1, 4, 3, 3)
+        acts = rng.randn(1, 4, 3, 3).astype(dtype)
         labels = [[1, 2]]
 
         fn_pt = RNNTLossNumba(blank=0, reduction='sum')
@@ -141,16 +156,17 @@ def test_case_small_random(self, device):
         fn_ag = RNNTLossPytorch(blank=0, reduction='sum')  # ag for automatic gradient computation
         ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device)
 
-        assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_random_test costs mismatch."
-        assert np.allclose(pt_grads, np_grads), "small_random_test gradient mismatch."
+        assert np.allclose(pt_cost, np_cost, atol=cost_threshold, rtol=rtol), "small_random_test costs mismatch."
+        assert np.allclose(pt_grads, np_grads, atol=grad_threshold, rtol=rtol), "small_random_test gradient mismatch."
 
-        assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "small_random_test costs mismatch."
-        assert np.allclose(pt_grads, ag_grads), "small_random_test gradient mismatch."
+        assert np.allclose(pt_cost, ag_cost, atol=cost_threshold, rtol=rtol), "small_random_test costs mismatch."
+        assert np.allclose(pt_grads, ag_grads, atol=grad_threshold, rtol=rtol), "small_random_test gradient mismatch."
 
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
+    @pytest.mark.parametrize('dtype', DTYPES)
     @pytest.mark.parametrize('fastemit_lambda', [1.0, 0.01, 0.00001])
-    def test_case_small_random_fastemit_reg(self, device, fastemit_lambda):
+    def test_case_small_random_fastemit_reg(self, device, dtype, fastemit_lambda):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
@@ -165,11 +181,12 @@ def test_case_small_random_fastemit_reg(self, device, fastemit_lambda):
         np_cost, np_grads = wrap_and_call(fn_np, acts, labels, device)
 
         assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_random_test costs mismatch."
-        assert np.allclose(pt_grads, np_grads, atol=1e-5, rtol=1e-5), "small_random_test gradient mismatch."
+        assert np.allclose(pt_grads, np_grads, rtol=1e-5), "small_random_test gradient mismatch."
 
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
-    def test_case_big_tensor(self, device):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_case_big_tensor(self, device, dtype):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
@@ -269,9 +286,13 @@ def test_case_big_tensor(self, device):
             ],
         ]
 
-        activations = np.array(activations)
+        activations = np.array(activations).astype(dtype)
         labels = [[1, 2], [1, 1]]
 
+        cost_threshold = 1e-8 if dtype == np.float32 else 5e-4
+        grad_threshold = 1e-8 if dtype == np.float32 else 1e-4
+        rtol = 1e-3 if dtype == np.float32 else 0.1
+
         fn_pt = RNNTLossNumba(blank=0, reduction='sum')
         pt_costs, pt_grads = wrap_and_call(fn_pt, activations, labels, device)
 
@@ -281,23 +302,30 @@ def test_case_big_tensor(self, device):
         fn_ag = RNNTLossPytorch(blank=0, reduction='sum')
         ag_costs, ag_grads = wrap_and_call(fn_ag, activations, labels, device)
 
-        assert np.allclose(pt_costs, sum(expected_costs)), "big_test average costs mismatch."
-        assert np.allclose(pt_grads, expected_grads, rtol=1e-3), "big_test grads for average cost mismatch."
+        assert np.allclose(pt_costs, sum(expected_costs), atol=cost_threshold), "big_test average costs mismatch."
+        assert np.allclose(
+            pt_grads, expected_grads, atol=grad_threshold, rtol=1e-3
+        ), "big_test grads for average cost mismatch."
 
-        assert np.allclose(pt_costs, np_costs), "big_test average costs mismatch."
-        assert np.allclose(pt_grads, np_grads, rtol=1e-3), "big_test grads for average cost mismatch."
+        assert np.allclose(pt_costs, np_costs, atol=cost_threshold, rtol=rtol), "big_test average costs mismatch."
+        assert np.allclose(
+            pt_grads, np_grads, atol=grad_threshold, rtol=rtol
+        ), "big_test grads for average cost mismatch."
 
-        assert np.allclose(pt_costs, ag_costs), "big_test average costs mismatch."
-        assert np.allclose(pt_grads, ag_grads, rtol=1e-3), "big_test grads for average cost mismatch."
+        assert np.allclose(pt_costs, ag_costs, atol=cost_threshold, rtol=rtol), "big_test average costs mismatch."
+        assert np.allclose(
+            pt_grads, ag_grads, atol=grad_threshold, rtol=rtol
+        ), "big_test grads for average cost mismatch."
 
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
-    def test_case_large_random(self, device):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_case_large_random(self, device, dtype):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         rng = np.random.RandomState(0)
-        acts = rng.randn(4, 8, 11, 5)
+        acts = rng.randn(4, 8, 11, 5).astype(dtype)
         labels = [
             [1, 2, 4, 3, 2, 2, 1, 1, 1, 1],
             [3, 2, 2, 3, 4, 1, 1, 1, 1, 1],
@@ -305,6 +333,10 @@ def test_case_large_random(self, device):
             [1, 1, 2, 1, 2, 3, 3, 1, 1, 1],
         ]
 
+        cost_threshold = 1e-8 if dtype == np.float32 else 5e-4
+        grad_threshold = 1e-8 if dtype == np.float32 else 1e-4
+        rtol = 1e-3 if dtype == np.float32 else 5e-2
+
         fn_pt = RNNTLossNumba(blank=0, reduction='sum')
         pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)
 
@@ -314,14 +346,15 @@ def test_case_large_random(self, device):
         fn_ag = RNNTLossPytorch(blank=0, reduction='sum')
         ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device)
 
-        assert np.allclose(pt_cost, np_cost, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch."
-        assert np.allclose(ag_cost, np_cost, atol=1e-5, rtol=1e-3), "large_random_test costs mismatch."
-        assert np.allclose(pt_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch."
-        assert np.allclose(ag_grads, np_grads, atol=1e-5, rtol=1e-3), "large_random_test gradient mismatch."
+        assert np.allclose(pt_cost, np_cost, atol=cost_threshold, rtol=rtol), "large_random_test costs mismatch."
+        assert np.allclose(ag_cost, np_cost, atol=cost_threshold, rtol=rtol), "large_random_test costs mismatch."
+        assert np.allclose(pt_grads, np_grads, atol=grad_threshold, rtol=rtol), "large_random_test gradient mismatch."
+        assert np.allclose(ag_grads, np_grads, atol=grad_threshold, rtol=rtol), "large_random_test gradient mismatch."
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
-    def test_case_small_clamp(self, device):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_case_small_clamp(self, device, dtype):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
@@ -333,9 +366,13 @@ def test_case_small_clamp(self, device):
                 [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
             ]
         ]
-        )
+        ).astype(dtype)
         labels = [[1, 2]]
 
+        cost_threshold = 1e-8 if dtype == np.float32 else 5e-4
+        grad_threshold = 1e-8 if dtype == np.float32 else 5e-5
+        rtol = 1e-5 if dtype == np.float32 else 1e-3
+
         fn_pt = RNNTLossNumba(blank=0, reduction='sum', clamp=GRAD_CLAMP)
         pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)
 
@@ -360,16 +397,17 @@ def test_case_small_clamp(self, device):
             ]
         )
 
-        assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(pt_grads, expected_grads), "small_test gradient mismatch."
+        assert np.allclose(pt_cost, expected_cost, atol=cost_threshold, rtol=rtol), "small_test costs mismatch."
+        assert np.allclose(pt_grads, expected_grads, atol=grad_threshold, rtol=rtol), "small_test gradient mismatch."
 
-        assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch."
+        assert np.allclose(pt_cost, np_cost, atol=cost_threshold, rtol=rtol), "small_test costs mismatch."
+        assert np.allclose(pt_grads, np_grads, atol=grad_threshold, rtol=rtol), "small_test gradient mismatch."
 
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
+    @pytest.mark.parametrize('dtype', DTYPES)
     @pytest.mark.parametrize('fastemit_lambda', [1.0, 0.01, 0.00001])
-    def test_case_small_fastemit_clamp(self, device, fastemit_lambda):
+    def test_case_small_fastemit_clamp(self, device, dtype, fastemit_lambda):
         if device == 'cuda':
             numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
@@ -381,9 +419,13 @@ def test_case_small_fastemit_clamp(self, device, fastemit_lambda):
                 [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.2, 0.1, 0.1], [0.7, 0.1, 0.2, 0.1, 0.1]],
             ]
         ]
-        )
+        ).astype(dtype)
         labels = [[1, 2]]
 
+        cost_threshold = 1e-8 if dtype == np.float32 else 1e-3
+        grad_threshold = 1e-8 if dtype == np.float32 else 5e-4
+        rtol = 1e-5 if dtype == np.float32 else 1e-3
+
         fn_pt = RNNTLossNumba(blank=0, reduction='sum', fastemit_lambda=fastemit_lambda, clamp=GRAD_CLAMP)
         pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device)
 
@@ -393,9 +435,9 @@ def test_case_small_fastemit_clamp(self, device, fastemit_lambda):
         expected_cost = 4.495666
         expected_cost += expected_cost * fastemit_lambda
 
-        assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(pt_cost, np_cost, rtol=1e-6), "small_test costs mismatch."
-        assert np.allclose(pt_grads, np_grads), "small_test gradient mismatch."
+        assert np.allclose(pt_cost, expected_cost, atol=cost_threshold, rtol=rtol), "small_test costs mismatch."
+        assert np.allclose(pt_cost, np_cost, atol=cost_threshold, rtol=rtol), "small_test costs mismatch."
+        assert np.allclose(pt_grads, np_grads, atol=grad_threshold, rtol=rtol), "small_test gradient mismatch."
     @pytest.mark.unit
     @pytest.mark.parametrize('device', DEVICES)
diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py
index 230b6b7c099f..cb5a9816e237 100644
--- a/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py
+++ b/tests/collections/asr/numba/rnnt_loss/utils/test_gpu_rnnt_kernel.py
@@ -25,8 +25,14 @@
 from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
 
 
+DTYPES = [torch.float32]
+if numba_utils.is_numba_cuda_fp16_supported():
+    DTYPES.append(torch.float16)
+
+
 def log_softmax(x, axis=-1):
     x = torch.from_numpy(x)  # zero-copy
+    x = x.float()
     x = torch.log_softmax(x, dim=axis)
     x = x.numpy()
     return x
@@ -42,12 +48,14 @@ def log_softmax_grad(x, axis=-1):
 class TestRNNTCUDAKernels:
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_compute_alphas_kernel(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_compute_alphas_kernel(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         random = np.random.RandomState(0)
         original_shape = [1, 5, 11, 3]
         B, T, U, V = original_shape
+        threshold = 1e-5 if dtype == torch.float32 else 3e-4
 
         # Numpy kernel
         x = random.randn(*original_shape)
@@ -67,7 +75,7 @@ def test_compute_alphas_kernel(self):
         else:
             stream = cuda.default_stream()
 
-        x_c = torch.tensor(x, device=device, dtype=torch.float32)
+        x_c = torch.tensor(x, device=device, dtype=dtype)
         labels_c = torch.tensor(labels, device=device, dtype=torch.int64)
 
         # Allocate workspace memory
@@ -100,22 +108,24 @@ def test_compute_alphas_kernel(self):
         alphas = alphas.view([B, T, U])
         diff = ground_alphas - alphas[0].cpu().numpy()
 
-        assert np.abs(diff).mean() <= 1e-5
-        assert np.square(diff).mean() <= 1e-10
+        assert np.abs(diff).mean() <= threshold
+        assert np.square(diff).mean() <= (threshold ** 2)
 
         ll_diff = ground_log_likelihood - llForward[0].cpu().numpy()
 
-        assert np.abs(ll_diff).mean() <= 1e-5
-        assert np.square(ll_diff).mean() <= 1e-10
+        assert np.abs(ll_diff).mean() <= threshold
+        assert np.square(ll_diff).mean() <= (threshold ** 2)
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_compute_betas_kernel(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_compute_betas_kernel(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         random = np.random.RandomState(0)
         original_shape = [1, 5, 11, 3]
         B, T, U, V = original_shape
+        threshold = 1e-5 if dtype == torch.float32 else 3e-4
 
         # Numpy kernel
         x = random.randn(*original_shape)
@@ -135,7 +145,7 @@ def test_compute_betas_kernel(self):
         else:
             stream = cuda.default_stream()
 
-        x_c = torch.tensor(x, device=device, dtype=torch.float32)
+        x_c = torch.tensor(x, device=device, dtype=dtype)
         labels_c = torch.tensor(labels, device=device, dtype=torch.int64)
 
         # Allocate workspace memory
@@ -168,17 +178,18 @@ def test_compute_betas_kernel(self):
         betas = betas.view([B, T, U])
         diff = ground_alphas - betas[0].cpu().numpy()
 
-        assert np.abs(diff).mean() <= 1e-5
-        assert np.square(diff).mean() <= 1e-10
+        assert np.abs(diff).mean() <= threshold
+        assert np.square(diff).mean() <= (threshold ** 2)
 
         ll_diff = ground_log_likelihood - llBackward[0].cpu().numpy()
 
-        assert np.abs(ll_diff).mean() <= 1e-5
-        assert np.square(ll_diff).mean() <= 1e-10
+        assert np.abs(ll_diff).mean() <= threshold
+        assert np.square(ll_diff).mean() <= (threshold ** 2)
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_compute_grads_kernel(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_compute_grads_kernel(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         fastemit_lambda = 0.0
@@ -187,6 +198,7 @@ def test_compute_grads_kernel(self):
         random = np.random.RandomState(0)
         original_shape = [1, 5, 11, 3]
         B, T, U, V = original_shape
+        threshold = 1e-5 if dtype == torch.float32 else 3e-5
 
         # Numpy kernel
         x = random.randn(*original_shape)
@@ -220,7 +232,7 @@ def test_compute_grads_kernel(self):
         else:
             stream = cuda.default_stream()
 
-        x_c = torch.tensor(x, device=device, dtype=torch.float32)
+        x_c = torch.tensor(x, device=device, dtype=dtype)
         labels_c = labels.clone().to(device=device, dtype=torch.int64)
 
         # Allocate workspace memory
@@ -283,12 +295,13 @@ def test_compute_grads_kernel(self):
         grads = grads.view([B, T, U, V])
         diff = true_grads - grads[0].cpu().numpy()
 
-        assert np.abs(diff).mean() <= 1e-5
-        assert np.square(diff).mean() <= 1e-10
+        assert np.abs(diff).mean() <= threshold
+        assert np.square(diff).mean() <= (threshold ** 2) * 5.0
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_compute_grads_kernel_fastemit(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_compute_grads_kernel_fastemit(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         fastemit_lambda = 0.001
@@ -297,6 +310,7 @@ def test_compute_grads_kernel_fastemit(self):
         random = np.random.RandomState(0)
         original_shape = [1, 5, 11, 3]
         B, T, U, V = original_shape
+        threshold = 1e-5 if dtype == torch.float32 else 3e-5
 
         # Numpy kernel
         x = random.randn(*original_shape)
@@ -330,7 +344,7 @@ def test_compute_grads_kernel_fastemit(self):
         else:
             stream = cuda.default_stream()
 
-        x_c = torch.tensor(x, device=device, dtype=torch.float32)
+        x_c = torch.tensor(x, device=device, dtype=dtype)
         labels_c = labels.clone().to(device=device, dtype=torch.int64)
 
         # Allocate workspace memory
@@ -393,12 +407,13 @@ def test_compute_grads_kernel_fastemit(self):
         grads = grads.view([B, T, U, V])
         diff = true_grads - grads[0].cpu().numpy()
 
-        assert np.abs(diff).mean() <= 1e-5
-        assert np.square(diff).mean() <= 1e-10
+        assert np.abs(diff).mean() <= threshold
+        assert np.square(diff).mean() <= (threshold ** 2) * 5
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_compute_grads_kernel_clamp(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_compute_grads_kernel_clamp(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         fastemit_lambda = 0.0
@@ -407,6 +422,7 @@ def test_compute_grads_kernel_clamp(self):
         random = np.random.RandomState(0)
         original_shape = [1, 5, 11, 3]
         B, T, U, V = original_shape
+        threshold = 1e-5 if dtype == torch.float32 else 3e-5
 
         # Numpy kernel
         x = random.randn(*original_shape)
@@ -440,7 +456,7 @@ def test_compute_grads_kernel_clamp(self):
         else:
             stream = cuda.default_stream()
 
-        x_c = torch.tensor(x, device=device, dtype=torch.float32)
+        x_c = torch.tensor(x, device=device, dtype=dtype)
         labels_c = labels.clone().to(device=device, dtype=torch.int64)
 
         # Allocate workspace memory
@@ -503,8 +519,8 @@ def test_compute_grads_kernel_clamp(self):
         grads = grads.view([B, T, U, V])
         diff = true_grads - grads[0].cpu().numpy()
 
-        assert np.abs(diff).mean() <= 1e-5
-        assert np.square(diff).mean() <= 1e-10
+        assert np.abs(diff).mean() <= threshold
+        assert np.square(diff).mean() <= (threshold ** 2) * 5
 
 
 class TestTDTCUDAKernels:
diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_reduce.py b/tests/collections/asr/numba/rnnt_loss/utils/test_reduce.py
index 7c2ba6a41208..5994d53e1d8f 100644
--- a/tests/collections/asr/numba/rnnt_loss/utils/test_reduce.py
+++ b/tests/collections/asr/numba/rnnt_loss/utils/test_reduce.py
@@ -20,17 +20,22 @@
 from nemo.core.utils import numba_utils
 from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
 
+DTYPES = [np.float32]
+if numba_utils.is_numba_cuda_fp16_supported():
+    DTYPES.append(np.float16)
+
 
 class TestRNNTCUDAReductions:
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_reduce_max(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_reduce_max(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
         random = np.random.RandomState(0)
         original_shape = [1, 5, 4, 3]
-        x = random.randn(*original_shape).reshape([-1])
-        dx = random.randn(*x.shape)
+        x = random.randn(*original_shape).reshape([-1]).astype(dtype)
+        dx = random.randn(*x.shape).astype(dtype)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -53,13 +58,14 @@ def test_reduce_max(self):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Reductions can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_reduce_exp(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_reduce_exp(self, dtype):
        numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
        random = np.random.RandomState(0)
        original_shape = [1, 5, 4, 2]
-        x = random.randn(*original_shape).reshape([-1])
-        dx = np.zeros_like(x)
+        x = random.randn(*original_shape).reshape([-1]).astype(dtype)
+        dx = np.zeros_like(x).astype(dtype)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
diff --git a/tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py b/tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py
index 243fe727e172..08f12da8324d 100644
--- a/tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py
+++ b/tests/collections/asr/numba/rnnt_loss/utils/test_rnnt_helper.py
@@ -20,11 +20,16 @@
 from nemo.core.utils import numba_utils
 from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
 
+DTYPES = [np.float32]
+if numba_utils.is_numba_cuda_fp16_supported():
+    DTYPES.append(np.float16)
+
 
 class TestRNNTHelper:
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_log_sum_exp(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_log_sum_exp(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -34,8 +39,9 @@ def _kernel(x, y):
             if x_pos < x.shape[0] and x_pos < y.shape[0]:
                 x[x_pos] = rnnt_helper.log_sum_exp(x[x_pos], y[x_pos])
 
-        x = np.zeros([8])  # np.random.rand(8192)
-        y = np.ones([8])  # np.random.rand(8192)
+        x = np.zeros([8]).astype(dtype)  # np.random.rand(8192)
+        y = np.ones([8]).astype(dtype)  # np.random.rand(8192)
+        threshold = 1e-5 if dtype == np.float32 else 2e-3
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -52,11 +58,12 @@ def _kernel(x, y):
         x_new = x_c.copy_to_host(stream=stream)
         del x_c, y_c
 
-        assert (x_new.sum() - 10.506093500145782) <= 1e-5
+        assert (x_new.sum() - 10.506093500145782) <= threshold
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_log_sum_exp_neg_inf(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_log_sum_exp_neg_inf(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -66,8 +73,8 @@ def _kernel(x, y):
             if x_pos < x.shape[0] and x_pos < y.shape[0]:
                 x[x_pos] = rnnt_helper.log_sum_exp(x[x_pos], y[x_pos])
 
-        x = np.asarray([global_constants.FP32_NEG_INF] * 8)
-        y = np.ones([len(x)])
+        x = np.asarray([global_constants.FP32_NEG_INF] * 8).astype(dtype)
+        y = np.ones([len(x)]).astype(dtype)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -88,7 +95,8 @@ def _kernel(x, y):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_div_up(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_div_up(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -98,8 +106,8 @@ def _kernel(x, y):
             if x_pos < x.shape[0] and x_pos < y.shape[0]:
                 x[x_pos] = rnnt_helper.div_up(x[x_pos], y[x_pos])
 
-        x = np.full([8], fill_value=10)  # np.random.rand(8192)
-        y = np.full([8], fill_value=2)  # np.random.rand(8192)
+        x = np.full([8], fill_value=10).astype(dtype)  # np.random.rand(8192)
+        y = np.full([8], fill_value=2).astype(dtype)  # np.random.rand(8192)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -121,7 +129,8 @@ def _kernel(x, y):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_add(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_add(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -131,8 +140,8 @@ def _kernel(x, y):
             if x_pos < x.shape[0] and x_pos < y.shape[0]:
                 x[x_pos] = rnnt_helper.add(x[x_pos], y[x_pos])
 
-        x = np.full([8], fill_value=10)  # np.random.rand(8192)
-        y = np.full([8], fill_value=2)  # np.random.rand(8192)
+        x = np.full([8], fill_value=10).astype(dtype)  # np.random.rand(8192)
+        y = np.full([8], fill_value=2).astype(dtype)  # np.random.rand(8192)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -154,7 +163,8 @@ def _kernel(x, y):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_maximum(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_maximum(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -164,8 +174,8 @@ def _kernel(x, y):
             if x_pos < x.shape[0] and x_pos < y.shape[0]:
                 x[x_pos] = rnnt_helper.maximum(x[x_pos], y[x_pos])
 
-        x = np.full([8], fill_value=10)  # np.random.rand(8192)
-        y = np.full([8], fill_value=2)  # np.random.rand(8192)
+        x = np.full([8], fill_value=10).astype(dtype)  # np.random.rand(8192)
+        y = np.full([8], fill_value=2).astype(dtype)  # np.random.rand(8192)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -187,7 +197,8 @@ def _kernel(x, y):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_identity(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_identity(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -197,7 +208,7 @@ def _kernel(x):
             if x_pos < x.shape[0]:
                 x[x_pos] = rnnt_helper.identity(x[x_pos])
 
-        x = np.full([8], fill_value=10)  # np.random.rand(8192)
+        x = np.full([8], fill_value=10).astype(dtype)  # np.random.rand(8192)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -218,7 +229,8 @@ def _kernel(x):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_negate(self):
+    @pytest.mark.parametrize('dtype', [np.float32, np.float16])
+    def test_negate(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -228,7 +240,7 @@ def _kernel(x):
             if x_pos < x.shape[0]:
                 x[x_pos] = rnnt_helper.negate(x[x_pos])
 
-        x = np.full([8], fill_value=10)  # np.random.rand(8192)
+        x = np.full([8], fill_value=10).astype(dtype)  # np.random.rand(8192)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -249,7 +261,8 @@ def _kernel(x):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_exponential(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_exponential(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -259,7 +272,7 @@ def _kernel(x):
             if x_pos < x.shape[0]:
                 x[x_pos] = rnnt_helper.exponential(x[x_pos])
 
-        x = np.random.rand(8)
+        x = np.random.rand(8).astype(dtype)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -281,7 +294,8 @@ def _kernel(x):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.unit
-    def test_log_plus(self):
+    @pytest.mark.parametrize('dtype', DTYPES)
+    def test_log_plus(self, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
         # wrapper kernel for device function that is tested
@@ -291,8 +305,8 @@ def _kernel(x, y):
             if x_pos < x.shape[0] and x_pos < y.shape[0]:
                 x[x_pos] = rnnt_helper.log_plus(x[x_pos], y[x_pos])
 
-        x = np.full([8], fill_value=10.0)  # np.random.rand(8192)
-        y = np.full([8], fill_value=2.0)  # np.random.rand(8192)
+        x = np.full([8], fill_value=10.0).astype(dtype)  # np.random.rand(8192)
+        y = np.full([8], fill_value=2.0).astype(dtype)  # np.random.rand(8192)
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -317,12 +331,15 @@ def _kernel(x, y):
 
     @pytest.mark.skipif(not cuda.is_available(), reason="CUDA Helpers can only be run when CUDA is available")
     @pytest.mark.parametrize('batch_size', [8, 128, 256])
     @pytest.mark.parametrize('fastemit_lambda', [0.0, 0.001])
+    @pytest.mark.parametrize('dtype', DTYPES)
     @pytest.mark.unit
-    def test_compute_costs_data(self, batch_size, fastemit_lambda):
+    def test_compute_costs_data(self, batch_size, fastemit_lambda, dtype):
         numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__)
 
+        np.random.seed(0)
         x = np.full([batch_size], fill_value=0.0)  # np.random.rand(8192)
-        y = np.random.randn(batch_size)  # np.random.rand(8192)
+        y = np.random.randn(batch_size).astype(dtype)  # np.random.rand(8192)
+        threshold = 1e-5 if dtype == np.float32 else 1e-5
 
         stream = cuda.stream()
         x_c = cuda.to_device(x, stream=stream)
@@ -340,11 +357,11 @@ def test_compute_costs_data(self, batch_size, fastemit_lambda):
         x_new = x_c.copy_to_host(stream=stream)
         del x_c, y_c
 
-        res = -(y.copy())
+        res = -(y.astype(np.float32).copy())
         res *= 1.0 + fastemit_lambda
 
         for i in range(len(x_new)):
-            assert x_new[i] == res[i], f"index failed {i}"
+            assert abs(x_new[i] - res[i]) < threshold, f"index failed {i}"
 
 
 if __name__ == '__main__':
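All of the parametrized tests above follow the same tolerance convention: tight absolute and relative thresholds for fp32 inputs, relaxed ones for fp16. A condensed, illustrative view of that convention (the helper and the exact constants below are representative of the per-test values, not shared test code):

```python
import numpy as np

def tolerances(dtype):
    # Representative of the per-test thresholds above: fp32 comparisons stay tight,
    # fp16 comparisons get a looser atol/rtol to absorb half-precision rounding.
    if dtype == np.float32:
        return dict(atol=1e-8, rtol=1e-5)
    return dict(atol=5e-4, rtol=1e-3)

pt_cost = np.float16(4.4961)    # e.g. cost from the Numba loss run in fp16
np_cost = np.float32(4.495666)  # reference cost from the numpy implementation
assert np.allclose(pt_cost, np_cost, **tolerances(np.float16))
```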