Skip to content

Commit cfc24b1

Browse files
committed
[XPU], support unified ckpt function
1 parent 7551730 commit cfc24b1

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

paddlenlp/trainer/plugins/unified_checkpoint.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1307,7 +1307,7 @@ def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serializa
13071307
else:
13081308
local_resume = False
13091309
local_resume = paddle.to_tensor([local_resume])
1310-
dist.all_reduce(local_resume, op=dist.ReduceOp.PROD)
1310+
dist.all_reduce(local_resume, op=dist.ReduceOp.MIN)
13111311
local_resume = local_resume.item()
13121312
return local_resume
13131313

@@ -1425,7 +1425,7 @@ def check_dynamic_load(args, weight_map, existed_files, is_master_weights=False,
14251425
else:
14261426
local_resume = False
14271427
local_resume = paddle.to_tensor([local_resume])
1428-
dist.all_reduce(local_resume, op=dist.ReduceOp.PROD)
1428+
dist.all_reduce(local_resume, op=dist.ReduceOp.MIN)
14291429
return local_resume.item()
14301430

14311431
# check whether the optimizer checkpoint files are complete.

paddlenlp/trainer/trainer.py

+6
Original file line numberDiff line numberDiff line change
@@ -1793,6 +1793,12 @@ def _load_rng_state(self, checkpoint):
17931793
for i in range(core.get_cuda_device_count()):
17941794
core.default_cuda_generator(i).set_state(checkpoint_rng_state["cuda"][i])
17951795

1796+
if core.is_compiled_with_xpu():
1797+
if not len(checkpoint_rng_state["xpu"]) == core.get_xpu_device_count():
1798+
raise ValueError("Length of xpu state list should be equal to the xpu device count")
1799+
for i in range(core.get_xpu_device_count()):
1800+
core.default_xpu_generator(i).set_state(checkpoint_rng_state["xpu"][i])
1801+
17961802
if paddle.device.get_all_custom_device_type() is not None:
17971803
custom_device_type = paddle.device.get_all_custom_device_type()
17981804
for device in custom_device_type:

0 commit comments

Comments
 (0)