Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions optimum/habana/transformers/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,35 +40,35 @@ def get_dtype(logits: Union[torch.Tensor, Tuple[torch.Tensor]]) -> Union[str, Li
logits_dtype = "float32"
return logits_dtype
elif isinstance(logits, tuple):
return [get_dtype(logits_tensor) for logits_tensor in logits]
return get_dtype(logits[0])
elif isinstance(logits, dict):
return {k: get_dtype(v) for k, v in logits.items()}
else:
raise TypeError(f"logits should be of type torch.Tensor or tuple, got {type(logits)} which is not supported")


def convert_into_dtypes(
preds: Union[np.ndarray, Tuple[np.ndarray]], dtypes: Union[str, List[str]]
preds: Union[np.ndarray, Tuple[np.ndarray]], dtype: str
) -> Union[np.ndarray, Tuple[np.ndarray]]:
"""
Convert preds into dtypes.
Convert preds into the target dtype.

Args:
preds (Union[np.ndarray, Tuple[np.ndarray]]): predictions to convert
dtypes (Union[str, List[str]]): dtypes used for the conversion
dtype (str): dtype used for the conversion

Raises:
TypeError: only torch.Tensor and tuple are supported
TypeError: only np.ndarray and tuple are supported

Returns:
Union[np.ndarray, Tuple[np.ndarray]]: converted preds
"""
if isinstance(preds, np.ndarray):
if preds.dtype == dtypes:
if preds.dtype == dtype:
return preds
else:
return preds.astype(dtypes)
return preds.astype(dtype)
elif isinstance(preds, tuple):
return tuple(convert_into_dtypes(preds_tensor, dtypes[i]) for i, preds_tensor in enumerate(preds))
return tuple(convert_into_dtypes(preds_tensor, dtype) for preds_tensor in preds)
else:
raise TypeError(f"preds should be of type np.ndarray or tuple, got {type(preds)} which is not supported")