Skip to content

Commit

Permalink
[PIR save/load] Open more tests for paddle.save and paddle.load (#64044)
Browse files Browse the repository at this point in the history
* open more tests for paddle.save and paddle.load

* fix
  • Loading branch information
changeyoung98 authored May 7, 2024
1 parent 012b9df commit 4d8c08f
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 43 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ void BindProgram(py::module *m) {
for (auto op : self->block()->ops()) {
for (auto var : op->results()) {
auto is_persistable =
var.attribute<BoolAttribute>("persistable");
var.attribute<BoolAttribute>(kAttrIsPersistable);
if (is_persistable && is_persistable.data()) {
if (var.defining_op()->isa<::pir::ParameterOp>()) {
std::string var_name = GetValueName(var);
Expand Down
8 changes: 7 additions & 1 deletion python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ def save(obj, path, protocol=4, **configs):
"'pickle_protocol' is a deprecated argument. Please use 'protocol' instead."
)

if isinstance(obj, Program):
if isinstance(obj, paddle.static.Program):
if in_pir_mode():
paddle.core.serialize_pir_program(
obj, path, 1, True, False, True
Expand Down Expand Up @@ -1200,6 +1200,12 @@ def load(path, **configs):
return tensor
except:
try:
if in_pir_mode():
program = paddle.static.Program()
paddle.core.deserialize_pir_program(
path, program, 1
)
return program
with _open_file_buffer(path, "rb") as f:
program_desc_str = f.read()
program = Program.parse_from_string(
Expand Down
34 changes: 34 additions & 0 deletions python/paddle/framework/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from paddle.base.framework import Parameter, Variable, static_only
from paddle.base.log_helper import get_logger
from paddle.base.wrapped_decorator import signature_safe_contextmanager
from paddle.framework import in_pir_mode

_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
Expand Down Expand Up @@ -323,3 +324,36 @@ def set_value(var, value, scope=None):
place = core.CUDAPlace(p.gpu_device_id())

t.set(value, place)


def get_value(var, scope=None):
"""
Get the value of variable or value in given scope.
Args:
scope(Scope, optional) : If `scope` is None, it will be set to global scope
obtained through 'paddle.static.global_scope()'. Otherwise, use `scope`.
Default: None
Returns:
Tensor, the value in given scope.
"""
if scope is not None and not isinstance(scope, core._Scope):
raise TypeError(
f"`scope` should be None or `paddle.static.Scope` type, but received {type(scope)}."
)

if scope is None:
scope = global_scope()
var_temp = scope.find_var(var.name)
if var_temp is None:
raise ValueError(f"Can not find Variable '{var.name}' in the Scope.")
t = var_temp.get_tensor()
return t


def is_pir_fetch_var(value):
if in_pir_mode() and value.get_defining_op().name() == "pd_op.fetch":
return True
return False
31 changes: 22 additions & 9 deletions test/deprecated/legacy_test/test_paddle_save_load_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import paddle
from paddle import base
from paddle.base import framework
from paddle.framework.io_utils import get_value, is_pir_fetch_var, set_value
from paddle.pir_utils import test_with_pir_api

IMAGE_SIZE = 784

Expand All @@ -42,6 +44,8 @@ def set_zero(self, prog, place, scope=None):
scope = base.global_scope()
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
if is_pir_fetch_var(var):
continue
ten = scope.find_var(var.name).get_tensor()
if ten is not None:
ten.set(np.zeros_like(np.array(ten)), place)
Expand All @@ -55,7 +59,7 @@ def predicate(var):
vars = filter(predicate, program.list_vars())
for var in vars:
paddle.save(
var.get_value(),
get_value(var),
os.path.join(dirname, var.name),
use_binary_format=True,
)
Expand All @@ -68,8 +72,9 @@ def predicate(var):
for var in var_list:
var_load = paddle.load(os.path.join(dirname, var.name))
# set var_load to scope
var.set_value(var_load)
set_value(var, var_load)

@test_with_pir_api
def test_replace_save_load_vars(self):
paddle.enable_static()
with new_program_scope():
Expand All @@ -91,6 +96,8 @@ def test_replace_save_load_vars(self):
base_map = {}
for var in prog.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
if is_pir_fetch_var(var):
continue
t = np.array(
base.global_scope().find_var(var.name).get_tensor()
)
Expand All @@ -112,7 +119,7 @@ def test_replace_save_load_vars(self):
)

for var in prog.list_vars():
if var.persistable:
if var.persistable and not is_pir_fetch_var(var):
new_t = np.array(
base.global_scope().find_var(var.name).get_tensor()
)
Expand All @@ -129,14 +136,15 @@ def test_replace_save_load_vars(self):
self.set_zero(prog, place)
self.replace_load_vars(prog, path_vars2)
for var in prog.list_vars():
if var.persistable:
if var.persistable and not is_pir_fetch_var(var):
new_t = np.array(
base.global_scope().find_var(var.name).get_tensor()
)
base_t = base_map[var.name]

np.testing.assert_array_equal(new_t, base_t)

@test_with_pir_api
def test_save_load_lod_tensor(self):
paddle.enable_static()
OUTPUT_NUM = 32
Expand All @@ -149,7 +157,7 @@ def test_save_load_lod_tensor(self):
OUTPUT_NUM,
name='fc_vars',
)
prog = base.default_main_program()
prog = paddle.static.default_main_program()
place = (
base.CPUPlace()
if not paddle.base.core.is_compiled_with_cuda()
Expand All @@ -167,15 +175,15 @@ def test_save_load_lod_tensor(self):
IMAGE_SIZE,
OUTPUT_NUM,
]:
tensor = var.get_value()
tensor = get_value(var)
paddle.save(
tensor, dirname + 'fc_vars.w_0', use_binary_format=True
)
break

origin = np.array(var.get_value())
var.set_value(np.zeros_like(origin))
is_zeros = np.array(var.get_value())
origin = np.array(get_value(var))
set_value(var, np.zeros_like(origin))
is_zeros = np.array(get_value(var))

loaded_tensor = paddle.load(dirname + 'fc_vars.w_0')
self.assertTrue(isinstance(loaded_tensor, base.core.LoDTensor))
Expand Down Expand Up @@ -234,6 +242,7 @@ def test_save_load_lod_tensor(self):
with self.assertRaises(NotImplementedError):
paddle.framework.io._load_lod_tensor(1)

@test_with_pir_api
def test_save_load_selected_rows(self):
paddle.enable_static()
place = (
Expand Down Expand Up @@ -299,3 +308,7 @@ def test_save_load_selected_rows(self):
paddle.framework.io._save_selected_rows(selected_rows, 1)
with self.assertRaises(NotImplementedError):
paddle.framework.io._load_selected_rows(1)


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 4d8c08f

Please sign in to comment.