Reduce memory usage and improve performance
zhaoguochun1995 committed Oct 10, 2024
1 parent a747b2f commit ce814ca
Showing 6 changed files with 112 additions and 33 deletions.
4 changes: 2 additions & 2 deletions op_tools/op_autocompare_hook.py
@@ -220,7 +220,7 @@ def compare_forward_relate(self):
if len(self.dtype_cast_dict) > 0:
dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"
print("\n" * 2)
print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(f"autocompare {self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(f"{self.current_location}")
print(self.op_forward_args_to_table())
print(dict_data_list_to_table(result_list))
@@ -272,7 +272,7 @@ def compare_backward_relate(self):
dtype_cast_info = f"cpu_dtype_cast_info(from:to): {self.dtype_cast_dict}"

print("\n" * 2)
print(f"{self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(f"autocompare {self.name} forward_id: {self.forward_op_id} {dtype_cast_info}")
print(f"{self.current_location}")
print(self.backward_args_table)
print(dict_data_list_to_table(backward_compare_result["result_list"]))
92 changes: 66 additions & 26 deletions op_tools/op_dtype_cast_hook.py
@@ -9,6 +9,8 @@
traverse_container,
get_dtype_cast_dict_form_str,
is_opname_match,
is_view_op,
is_dtype_cast_op,
garbage_collect
)
from .pretty_print import dict_data_list_to_table
@@ -26,10 +28,21 @@ def before_call_op(self, *args, **kwargs):
)
self.dtype_cast_dict = get_dtype_cast_dict_form_str(self.dtype_cast_config_str)
with DisableHookGuard():
self.args_raw = self.args
self.is_cpu_op, self.device = is_cpu_op(*args, **kwargs)
if self.is_cpu_op:
return

self.raw_ins_dtype_list = []
for arg in traverse_container(self.args):
if isinstance(arg, torch.Tensor):
self.raw_ins_dtype_list.append(arg.dtype)
else:
self.raw_ins_dtype_list.append(None)

for dtype in set(self.dtype_cast_dict.keys()):
if dtype not in self.raw_ins_dtype_list:
self.dtype_cast_dict.pop(dtype)

self.args = to_device(
self.device,
self.args,
@@ -43,56 +56,83 @@ def before_call_op(self, *args, **kwargs):
detach=False,
)
self.dtype_cast_back_dict = {}
self.ins_list = []
self.ins_dtype_list = []
for arg in traverse_container(self.args):
self.ins_list.append(arg)

self.raw_ins_list = []
for arg in traverse_container(self.args_raw):
self.raw_ins_list.append(arg)
if isinstance(arg, torch.Tensor):
self.ins_dtype_list.append(arg.dtype)
else:
self.ins_dtype_list.append(None)

self.data_dict_list = []
for i in range(len(self.ins_list)):
if isinstance(self.ins_list[i], torch.Tensor):
if self.ins_list[i].dtype != self.raw_ins_list[i].dtype:
self.dtype_cast_back_dict[self.ins_list[i].dtype] = self.raw_ins_list[i].dtype
data_dict = {
"name": self.name,
"target": f"input[{i}]",
"action": f"{self.raw_ins_list[i].dtype} -> {self.ins_list[i].dtype}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)
for i in range(len(self.ins_dtype_list)):
if self.ins_dtype_list[i] != self.raw_ins_dtype_list[i]:
self.dtype_cast_back_dict[self.ins_dtype_list[i]] = self.raw_ins_dtype_list[i]
data_dict = {
"name": self.name,
"target": f"input[{i}]",
"action": f"{self.raw_ins_dtype_list[i]} -> {self.ins_dtype_list[i]}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)

def after_call_op(self, result):
if self.is_cpu_op:
return
with DisableHookGuard():
self.result_raw = result
self.raw_result_dtype_list = []
for arg in traverse_container(self.result):
if isinstance(arg, torch.Tensor):
self.raw_result_dtype_list.append(arg.dtype)
else:
self.raw_result_dtype_list.append(None)

self.result = to_device(
self.device,
self.result,
dtype_cast_dict=self.dtype_cast_back_dict,
detach=False,
)

self.result_dtype_list = []
for arg in traverse_container(self.result):
if isinstance(arg, torch.Tensor):
self.result_dtype_list.append(arg.dtype)
else:
self.result_dtype_list.append(None)

i = -1
for out in traverse_container(self.result_raw):
for out in traverse_container(self.raw_result_dtype_list):
i += 1
if isinstance(out, torch.Tensor) and out.dtype in self.dtype_cast_back_dict.keys():
if out in self.dtype_cast_back_dict.keys():
data_dict = {
"name": self.name,
"target": f"output[{i}]",
"action": f"{out.dtype} -> {self.dtype_cast_back_dict[out.dtype]}",
"action": f"{out} -> {self.dtype_cast_back_dict[out]}",
"config": self.dtype_cast_config_str,
}
self.data_dict_list.append(data_dict)
if len(self.data_dict_list) > 0:
print(dict_data_list_to_table(self.data_dict_list))
self = None
garbage_collect()
if len(self.data_dict_list) > 0:
print("\n" * 2, f"cast_dtype {self.name} forward_id: {self.id}")
print(f"{self.current_location}")
print(dict_data_list_to_table(self.data_dict_list))
print("\n" * 2)
result = self.result
self = None
garbage_collect()
return result

def is_should_apply(self, *args, **kwargs):
if is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_DISABLE_LIST", "")):
return False

if is_view_op(self.name):
return False

if is_dtype_cast_op(self.name, *args, **kwargs):
return False

return True

def is_should_apply_backward(self, *args, **kwargs):

return is_opname_match(self.name, os.getenv("OP_DTYPE_CAST_LIST", ".*"))
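
The memory reduction in this file comes from no longer keeping the raw inputs and outputs alive on the hook (`self.args_raw`, `self.result_raw`); only per-argument dtype lists are recorded, and the cast dict is pruned to dtypes that actually occur in the inputs. A minimal standalone sketch of that bookkeeping idea (helper names here are illustrative, not part of op_tools):

    import torch

    def prune_cast_dict(dtype_cast_dict, args):
        # Drop cast rules whose source dtype never appears among the tensor
        # inputs, so no unnecessary conversions (and copies) are set up.
        present = {a.dtype for a in args if isinstance(a, torch.Tensor)}
        return {src: dst for src, dst in dtype_cast_dict.items() if src in present}

    def record_input_dtypes(args):
        # Keep only a lightweight per-argument dtype list (None for non-tensors)
        # instead of holding references to the tensors themselves.
        return [a.dtype if isinstance(a, torch.Tensor) else None for a in args]

    cast = {torch.float16: torch.float32, torch.float64: torch.float32}
    args = (torch.randn(2, 2, dtype=torch.float16), 3, "alpha")
    print(prune_cast_dict(cast, args))    # {torch.float16: torch.float32}
    print(record_input_dtypes(args))      # [torch.float16, None, None]
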
2 changes: 1 addition & 1 deletion op_tools/op_fallback_hook.py
@@ -98,7 +98,7 @@ def dump_op_args(self):
dtype_cast_info = "cpu_dtype_cast_info: " + str(self.dtype_cast_dict)

print("\n" * 2)
print(f"{self.name} forward_id: {self.id} {dtype_cast_info}")
print(f"fallback {self.name} forward_id: {self.id} {dtype_cast_info}")
print(f"{self.current_location}")
print(table)
print("\n" * 2)
11 changes: 11 additions & 0 deletions op_tools/test/test_opname_match.py
@@ -50,6 +50,17 @@ def test_get_dtype_cast_dict_from_config(self):
},
)

def test_get_dtype_cast_dict_from_config2(self):
dtype_cast_dict = get_dtype_cast_dict_form_str(" torch.float32 ->torch.float16, torch.float64 -> torch.float16, torch.int64-> torch.int32 ")
self.assertEqual(
dtype_cast_dict,
{
torch.float32: torch.float16,
torch.float64: torch.float16,
torch.int64: torch.int32,
},
)


if __name__ == "__main__":
unittest.main()
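
The added test checks that the config parser tolerates stray whitespace around dtype names and arrows. A rough standalone sketch of that parsing rule (not the op_tools implementation, which resolves names via get_function_from_string):

    import torch

    def parse_dtype_cast_config(config):
        # Split "src -> dst" pairs on commas, strip whitespace around every
        # piece, and resolve "torch.float16"-style names to torch dtype objects.
        mapping = {}
        for item in config.split(","):
            src, dst = (part.strip() for part in item.split("->"))
            mapping[getattr(torch, src.split(".")[-1])] = getattr(torch, dst.split(".")[-1])
        return mapping

    print(parse_dtype_cast_config(" torch.float32 ->torch.float16, torch.int64-> torch.int32 "))
    # {torch.float32: torch.float16, torch.int64: torch.int32}
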
6 changes: 6 additions & 0 deletions op_tools/test/test_tool_with_special_op.py
@@ -128,6 +128,12 @@ def test_exp(self):
y = x.exp()
y.backward(torch.ones_like(y))

def test_dtype_cast(self):
with op_tools.OpDtypeCast():
x = torch.randn(3, 4, 5, dtype=torch.float16, device="cuda", requires_grad=True)
y = x.to(torch.float32)
y.backward(torch.ones_like(y))


if __name__ == "__main__":
unittest.main()
30 changes: 26 additions & 4 deletions op_tools/utils.py
@@ -86,9 +86,9 @@ def is_inplace_op(name):

def get_function_from_string(func_str):
parts = func_str.split(".")
attrs = [importlib.import_module(parts[0])]
attrs = [importlib.import_module(parts[0].strip())]
for i in range(0, len(parts) - 1):
attr = getattr(attrs[i], parts[i + 1])
attr = getattr(attrs[i], parts[i + 1].strip())
attrs.append(attr)

return attrs[len(parts) - 1]
Expand All @@ -101,6 +101,7 @@ def get_dtype_cast_dict_form_str(config):
dtype_cast_dict = dict()
if config is not None:
for item in config.split(","):
item = item.strip()
dtype_cast_dict[get_function_from_string(item.split("->")[0])] = get_function_from_string(item.split("->")[1])
return dtype_cast_dict

@@ -193,6 +194,27 @@ def is_random_number_gen_op(name):
return name in RANDOM_NUMBER_GEN_OPS


def is_dtype_cast_op(name, *args, **kwargs):
if "dtype" in kwargs.keys() and kwargs["dtype"] is not None:
return True
for arg in args:
if isinstance(arg, torch.dtype):
return True
dtype_cast_op = [
"torch.Tensor.double",
"torch.Tensor.float",
"torch.Tensor.half",
"torch.Tensor.bfloat16",
"torch.Tensor.long",
"torch.Tensor.int",
"torch.Tensor.short",
"torch.Tensor.bool",
]
if name in dtype_cast_op:
return True
return False


def tensor_max_diff(a, b):
a_cpu, b_cpu = a.cpu(), b.cpu()
if a_cpu.dtype == torch.bool:
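
The new is_dtype_cast_op classifies an op as a dtype cast when it receives an explicit dtype (positionally or as a keyword) or is one of the Tensor.<dtype>() shorthands, which lets the dtype-cast hook exclude such ops. A small usage-style sketch of the same heuristic (a standalone reimplementation for illustration, not an import from op_tools):

    import torch

    DTYPE_CAST_SHORTHANDS = {
        "torch.Tensor.double", "torch.Tensor.float", "torch.Tensor.half",
        "torch.Tensor.bfloat16", "torch.Tensor.long", "torch.Tensor.int",
        "torch.Tensor.short", "torch.Tensor.bool",
    }

    def looks_like_dtype_cast(name, *args, **kwargs):
        # Explicit dtype keyword, an explicit torch.dtype positional argument,
        # or one of the Tensor.<dtype>() shorthand methods.
        if kwargs.get("dtype") is not None:
            return True
        if any(isinstance(a, torch.dtype) for a in args):
            return True
        return name in DTYPE_CAST_SHORTHANDS

    print(looks_like_dtype_cast("torch.Tensor.to", torch.float32))   # True
    print(looks_like_dtype_cast("torch.Tensor.half"))                # True
    print(looks_like_dtype_cast("torch.Tensor.exp"))                 # False
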
@@ -386,7 +408,7 @@ def __init__(self) -> None:
def is_shoule_collect(self):
self.current_rss = psutil.Process().memory_info().rss
self.current_device_memory_usage = torch.cuda.memory_allocated()
if self.current_rss - self.rss > self.max_diff:
if (self.current_rss - self.rss > self.max_diff) or (self.current_device_memory_usage - self.device_memory_usage > self.max_diff):
return True
else:
return False
@@ -396,7 +418,7 @@ def collect(self):
self.rss = max(self.rss, psutil.Process().memory_info().rss)
self.device_memory_usage = max(self.device_memory_usage, torch.cuda.memory_allocated())
print(
f"GarbageCollectEvaluate: after collect : rss: {self.rss >> 20} current_rss: {self.current_rss >> 20} max_diff: {self.max_diff} device_memory_usage: {self.device_memory_usage >> 20} current_device_memory_usage: {self.current_device_memory_usage >> 20}" # noqa: E501
f"GarbageCollectEvaluate: after collect : rss: {self.rss >> 20} MB, current_rss: {self.current_rss >> 20} MB, max_diff: {self.max_diff>>20} MB, device_memory_usage: {self.device_memory_usage >> 20} MB, current_device_memory_usage: {self.current_device_memory_usage >> 20} MB" # noqa: E501
)


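The GarbageCollectEvaluate change makes collection trigger when either host RSS or device memory has grown past max_diff since the last collection, instead of host RSS alone, and adds units to the log line. A hedged sketch of that threshold idea (class and parameter names here are illustrative, not the op_tools API):

    import gc
    import psutil
    import torch

    class GcThresholdSketch:
        def __init__(self, max_diff_bytes=1 << 30):
            self.max_diff = max_diff_bytes
            self.rss = psutil.Process().memory_info().rss
            self.device_mem = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0

        def should_collect(self):
            # Trigger when host RSS *or* device memory grew past the threshold
            # since the last collection.
            rss_now = psutil.Process().memory_info().rss
            dev_now = torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
            return (rss_now - self.rss > self.max_diff) or (dev_now - self.device_mem > self.max_diff)

        def collect(self):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Track high-water marks so the next trigger measures new growth.
            self.rss = max(self.rss, psutil.Process().memory_info().rss)
            self.device_mem = max(self.device_mem,
                                  torch.cuda.memory_allocated() if torch.cuda.is_available() else 0)

    evaluator = GcThresholdSketch(max_diff_bytes=512 << 20)  # 512 MB growth threshold
    if evaluator.should_collect():
        evaluator.collect()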
