Commit
dist op can be mocked by fp32
yangbofun committed Dec 19, 2024
1 parent 46473f5 commit 04d18d1
Showing 3 changed files with 77 additions and 1 deletion.
39 changes: 39 additions & 0 deletions ditorch/test/individual/test_utils.py
@@ -0,0 +1,39 @@
# Copyright (c) 2024, DeepLink.

import unittest
import torch
from ditorch.utils import is_to_fp32_tensor


@is_to_fp32_tensor(to_fp32=True)
def test_func_to_fp32(tensor: torch.Tensor, tensors_list, tensors_dict):
assert (
tensor.dtype == torch.float32
), f"tensor's dtype is not fp32, but {tensor.dtype}"
    for t in tensors_list:
        if isinstance(t, torch.Tensor):
            assert (
                t.dtype == torch.float32
            ), f"tensor's dtype is not fp32, but {t.dtype}"
for k, v in tensors_dict.items():
if isinstance(v, torch.Tensor):
assert (
v.dtype == torch.float32
), f"tensor's dtype is not fp32, but {v.dtype}"


class TestUtils(unittest.TestCase):
    def test_is_to_fp32_tensor(self):
tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
tensors_list = [
torch.tensor([1.0, 1.0, 1.0], dtype=torch.float16),
torch.tensor([2.0, 2.0, 2.0], dtype=torch.float16),
torch.tensor([3.0, 3.0, 3.0], dtype=torch.float16),
]
tensors_dict = {
"tensor1": torch.tensor([1.0, 1.0, 1.0], dtype=torch.float16),
"tensor2": torch.tensor([2.0, 2.0, 2.0], dtype=torch.float16),
"tensor3": torch.tensor([3.0, 3.0, 3.0], dtype=torch.float16),
}

test_func_to_fp32(tensor, tensors_list, tensors_dict)
7 changes: 6 additions & 1 deletion ditorch/torch_npu_adapter/mock_dist.py
@@ -1,7 +1,9 @@
# Copyright (c) 2024, DeepLink.
import functools
import torch
import torch.nn.functional as F # noqa
import torch.distributed as dist
from ditorch.utils import is_to_fp32_tensor


def copy_inp(tensor_dest, tensor_src):
@@ -17,13 +19,14 @@ def copy_inp(tensor_dest, tensor_src):
tensor_dest.copy_(tensor_src)


-def mock_dist():
+def mock_dist(use_fp32=False):
dist_all_reduce = dist.all_reduce
dist_reduce = dist.reduce
dist__reduce_scatter_base = dist._reduce_scatter_base
dist_reduce_scatter_tensor = dist.reduce_scatter_tensor
dist_reduce_scatter = dist.reduce_scatter

@is_to_fp32_tensor(use_fp32)
def dist_reduce_npu(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False):
if op == dist.ReduceOp.AVG:
handle = dist_reduce(tensor, dst, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
@@ -37,6 +40,7 @@ def dist_reduce_npu(tensor, dst, op=dist.ReduceOp.SUM, group=None, async_op=False):
handle = dist_reduce(tensor, op=op, group=group, async_op=async_op)
return handle

@is_to_fp32_tensor(use_fp32)
def dist_all_reduce_npu(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
if op == dist.ReduceOp.AVG:
handle = dist_all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
@@ -49,6 +53,7 @@ def dist_all_reduce_npu(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
handle = dist_all_reduce(tensor, op=op, group=group, async_op=async_op)
return handle

@is_to_fp32_tensor(use_fp32)
def dist__reduce_scatter_base_npu(dist_reduce_scatter_func, output, input, op=dist.ReduceOp.SUM, group=None, async_op=False):
if op == dist.ReduceOp.AVG:
handle = dist_reduce_scatter_func(output, input, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
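For orientation, a minimal usage sketch of the new use_fp32 switch (not part of this commit). It assumes that mock_dist() finishes by patching the torch.distributed functions it wraps (the tail of the function is collapsed in this diff) and that torch_npu with an hccl process group is available:

# Hypothetical usage sketch, not part of this commit.
import torch
import torch.distributed as dist
from ditorch.torch_npu_adapter.mock_dist import mock_dist

dist.init_process_group(backend="hccl")  # assumption: NPU backend via torch_npu

# With use_fp32=True, the wrapped reduce ops receive fp32 copies of any
# fp16/bf16 tensor arguments (via the @is_to_fp32_tensor decorator defined
# in ditorch/utils.py below).
mock_dist(use_fp32=True)

grad = torch.ones(4, dtype=torch.float16).npu()  # .npu() comes from torch_npu
dist.all_reduce(grad, op=dist.ReduceOp.AVG)      # AVG is rerouted through SUM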
32 changes: 32 additions & 0 deletions ditorch/utils.py
@@ -0,0 +1,32 @@
# Copyright (c) 2024, DeepLink.

import functools
import torch

def to_fp32_if_tensor(in_tensors):
if isinstance(in_tensors, torch.Tensor):
return in_tensors.to(torch.float32)
elif isinstance(in_tensors, (list, tuple)):
return [to_fp32_if_tensor(tensor) for tensor in in_tensors]
elif isinstance(in_tensors, dict):
return {k: to_fp32_if_tensor(v) for k, v in in_tensors.items()}
else:
return in_tensors


def is_to_fp32_tensor(to_fp32: bool):
def to_fp32_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Convert positional arguments to fp32 if possible
args_fp32 = to_fp32_if_tensor(args)
# Convert keyword arguments to fp32 if possible
kwargs_fp32 = to_fp32_if_tensor(kwargs)
return func(*args_fp32, **kwargs_fp32)
if to_fp32:
print(f"{func.__name__} mocked by fp32.")
return wrapper
else:
return func
return to_fp32_wrapper

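As a quick standalone check of the decorator (not part of the commit; dtype_of is a made-up helper for illustration), any fp16 tensor argument is upcast before the wrapped function body runs:

# Illustration only: dtype_of is a hypothetical helper, not ditorch API.
import torch
from ditorch.utils import is_to_fp32_tensor


@is_to_fp32_tensor(to_fp32=True)  # prints "dtype_of mocked by fp32." at decoration time
def dtype_of(tensor):
    # The wrapper has already replaced the fp16 input with an fp32 copy.
    return tensor.dtype


print(dtype_of(torch.ones(3, dtype=torch.float16)))  # torch.float32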