Showing 3 changed files with 77 additions and 1 deletion.
@@ -0,0 +1,39 @@
# Copyright (c) 2024, DeepLink.

import unittest
import torch
from ditorch.utils import is_to_fp32_tensor


@is_to_fp32_tensor(to_fp32=True)
def test_func_to_fp32(tensor: torch.Tensor, tensors_list, tensors_dict):
    # The decorator should have converted every tensor argument to fp32
    # before this function body runs.
    assert (
        tensor.dtype == torch.float32
    ), f"tensor's dtype is not fp32, but {tensor.dtype}"
    for tensor in tensors_list:
        if isinstance(tensor, torch.Tensor):
            assert (
                tensor.dtype == torch.float32
            ), f"tensor's dtype is not fp32, but {tensor.dtype}"
    for k, v in tensors_dict.items():
        if isinstance(v, torch.Tensor):
            assert (
                v.dtype == torch.float32
            ), f"tensor's dtype is not fp32, but {v.dtype}"


class TestUtils(unittest.TestCase):
    def test_is_to_fp32_tensor(self):
        # All inputs start as fp16; the decorated function asserts they arrive as fp32.
        tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
        tensors_list = [
            torch.tensor([1.0, 1.0, 1.0], dtype=torch.float16),
            torch.tensor([2.0, 2.0, 2.0], dtype=torch.float16),
            torch.tensor([3.0, 3.0, 3.0], dtype=torch.float16),
        ]
        tensors_dict = {
            "tensor1": torch.tensor([1.0, 1.0, 1.0], dtype=torch.float16),
            "tensor2": torch.tensor([2.0, 2.0, 2.0], dtype=torch.float16),
            "tensor3": torch.tensor([3.0, 3.0, 3.0], dtype=torch.float16),
        }

        test_func_to_fp32(tensor, tensors_list, tensors_dict)
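A minimal sketch of the behavior this test relies on, outside of unittest. It assumes ditorch is importable; the function name demo_fp32 is illustrative and not part of the commit:

import torch
from ditorch.utils import is_to_fp32_tensor

@is_to_fp32_tensor(to_fp32=True)
def demo_fp32(x: torch.Tensor):
    # By the time the body runs, the fp16 argument has been upcast to fp32.
    return x.dtype

print(demo_fp32(torch.ones(3, dtype=torch.float16)))  # torch.float32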
@@ -0,0 +1,32 @@
# Copyright (c) 2024, DeepLink.

import functools
import torch


def to_fp32_if_tensor(in_tensors):
    # Recursively convert tensors (including those nested in lists, tuples,
    # and dicts) to float32; leave any non-tensor values unchanged.
    if isinstance(in_tensors, torch.Tensor):
        return in_tensors.to(torch.float32)
    elif isinstance(in_tensors, (list, tuple)):
        return [to_fp32_if_tensor(tensor) for tensor in in_tensors]
    elif isinstance(in_tensors, dict):
        return {k: to_fp32_if_tensor(v) for k, v in in_tensors.items()}
    else:
        return in_tensors


def is_to_fp32_tensor(to_fp32: bool):
    def to_fp32_wrapper(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Convert positional arguments to fp32 if possible
            args_fp32 = to_fp32_if_tensor(args)
            # Convert keyword arguments to fp32 if possible
            kwargs_fp32 = to_fp32_if_tensor(kwargs)
            return func(*args_fp32, **kwargs_fp32)

        if to_fp32:
            print(f"{func.__name__} mocked by fp32.")
            return wrapper
        else:
            return func

    return to_fp32_wrapper
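A small usage sketch of to_fp32_if_tensor on a nested structure (the data here is illustrative, not from the commit):

import torch
from ditorch.utils import to_fp32_if_tensor

batch = {
    "x": torch.zeros(2, dtype=torch.float16),
    "meta": {"scale": 1.0, "mask": torch.ones(2, dtype=torch.bfloat16)},
    "extra": [torch.tensor([1], dtype=torch.int64), "label"],
}
converted = to_fp32_if_tensor(batch)
print(converted["x"].dtype)             # torch.float32
print(converted["meta"]["mask"].dtype)  # torch.float32
print(converted["extra"][0].dtype)      # torch.float32 (integer tensors are converted too)
print(converted["extra"][1])            # non-tensor values pass through unchanged

Note that list and tuple inputs both come back as lists; that is harmless for the decorator, which only unpacks the converted arguments with *args_fp32 and **kwargs_fp32.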