From 9c0c3fd14844c750c78e971f1388721105720e4e Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Tue, 5 Mar 2024 07:23:10 +0000 Subject: [PATCH] Fix `get_dtype` and `convert_into_dtypes` --- optimum/habana/transformers/trainer_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/optimum/habana/transformers/trainer_utils.py b/optimum/habana/transformers/trainer_utils.py index e7dc5eaf2f..edc6bfe29e 100644 --- a/optimum/habana/transformers/trainer_utils.py +++ b/optimum/habana/transformers/trainer_utils.py @@ -40,7 +40,7 @@ def get_dtype(logits: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Union[str, Li logits_dtype = "float32" return logits_dtype elif isinstance(logits, tuple): - return [get_dtype(logits_tensor) for logits_tensor in logits] + return get_dtype(logits[0]) elif isinstance(logits, dict): return {k: get_dtype(v) for k, v in logits.items()} else: @@ -48,27 +48,27 @@ def get_dtype(logits: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Union[str, Li def convert_into_dtypes( - preds: Union[np.ndarray, Tuple[np.ndarray]], dtypes: Union[str, List[str]] + preds: Union[np.ndarray, Tuple[np.ndarray]], dtype: str ) -> Union[np.ndarray, Tuple[np.ndarray]]: """ - Convert preds into dtypes. + Convert preds into the target dtype. Args: preds (Union[np.ndarray, Tuple[np.ndarray]]): predictions to convert - dtypes (Union[str, List[str]]): dtypes used for the conversion + dtype (str): dtype used for the conversion Raises: - TypeError: only torch.Tensor and tuple are supported + TypeError: only np.ndarray and tuple are supported Returns: Union[np.ndarray, Tuple[np.ndarray]]: converted preds """ if isinstance(preds, np.ndarray): - if preds.dtype == dtypes: + if preds.dtype == dtype: return preds else: - return preds.astype(dtypes) + return preds.astype(dtype) elif isinstance(preds, tuple): - return tuple(convert_into_dtypes(preds_tensor, dtypes[i]) for i, preds_tensor in enumerate(preds)) + return tuple(convert_into_dtypes(preds_tensor, dtype) for preds_tensor in preds) else: raise TypeError(f"preds should be of type np.ndarray or tuple, got {type(preds)} which is not supported")