Commit 257a5bc

[FIX DDP] fix ddp (PaddlePaddle#8549)
* enable trainer tests.
1 parent 02fd721 commit 257a5bc

5 files changed: +88 -86 lines changed

paddlenlp/trainer/trainer.py (+2 -10)

@@ -1771,16 +1771,8 @@ def _wrap_model(self, model, training=True):
         in_sep_parallel_mode = self.args.sep_parallel_degree > 1

         # Multi-gpu training
-        if (
-            self.args.world_size > 1
-            and not self.args.use_hybrid_parallel
-            or not (
-                in_pipeline_parallel_mode
-                or in_sharding_parallel_mode
-                or in_tensor_parallel_mode
-                or in_sep_parallel_mode
-            )
-        ):
+        if self.args.world_size > 1 and (not self.args.use_hybrid_parallel):
+            # MOE use DDP to broadcaset parameters.
             model = paddle.DataParallel(model)
             # Distributed training (should be after fp16 initialization)

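In the removed condition, "and" binds tighter than "or", so the predicate reduced to (world_size > 1 and not use_hybrid_parallel) or not (any sub-parallel mode), which wrapped the model in paddle.DataParallel even on a single process with no parallelism configured. The replacement wraps only for plain multi-process data parallelism, the case where MoE relies on DDP to broadcast parameters. A minimal sketch of the two predicates; the function names and sample values below are illustrative, not from the repository:

    # Sketch of the old vs. new wrapping condition in _wrap_model (illustrative only).
    def should_wrap_old(world_size, use_hybrid_parallel, in_pp, in_sharding, in_tp, in_sep):
        # "and" binds tighter than "or": the right-hand "not (...)" term can force
        # DataParallel wrapping on its own, regardless of world_size.
        return (world_size > 1 and not use_hybrid_parallel) or not (in_pp or in_sharding or in_tp or in_sep)

    def should_wrap_new(world_size, use_hybrid_parallel):
        # Wrap only for plain multi-process data parallelism.
        return world_size > 1 and not use_hybrid_parallel

    # Single process, no parallel modes configured:
    print(should_wrap_old(1, False, False, False, False, False))  # True  (old: wrapped anyway)
    print(should_wrap_new(1, False))                               # False (new: left unwrapped)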
paddlenlp/trainer/training_args.py (+1 -1)

@@ -1406,7 +1406,7 @@ def is_segment_parallel_supported():
         if world_size > 1:
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized():
                 if self.unified_checkpoint:
-                    self.use_hybrid_parallel = True
+                    # DP use hybrid group
                     strategy = fleet.DistributedStrategy()
                     fleet.init(is_collective=True, strategy=strategy)
                 else:

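With this change, a pure data-parallel run that enables unified_checkpoint no longer flips use_hybrid_parallel; fleet is still initialized with a default collective strategy so the data-parallel processes share the hybrid communication group. A minimal, hedged sketch of that initialization (not the exact PaddleNLP code path; assumes the script is started with paddle.distributed.launch):

    import paddle.distributed.fleet as fleet

    # Initialize the collective fleet context with a default strategy; the trainer
    # then wraps the model with paddle.DataParallel (see trainer.py above) instead
    # of switching to hybrid parallelism.
    strategy = fleet.DistributedStrategy()
    fleet.init(is_collective=True, strategy=strategy)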
pyproject.toml (+2 -2)

@@ -10,7 +10,7 @@ exclude = ['.flake8']

 [tool.pytest.ini_options]
 minversion = "6.0"
-addopts = "-ra -q --ignore model_zoo/gpt-3/"
+addopts = "-ra -q --dist loadgroup"
 pythonpath = ["."]
 testpaths = [
     "tests/data",
@@ -28,7 +28,7 @@ testpaths = [
     "tests/prompt",
     # "tests/taskflow", TODO (paddle 2.5.1 breaks this test suite, debug later)
     "tests/utils",
-    "model_zoo",
+    # "model_zoo",
 ]
 python_files = [
     "test.py",

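On the addopts change: --dist loadgroup is a pytest-xdist scheduling mode that places all tests sharing the same xdist_group mark on one worker, which keeps distributed trainer tests that compete for the same GPUs or ports from running concurrently. A short sketch of how tests opt into a group; the group name below is illustrative:

    # Run with: pytest -n <workers> --dist loadgroup  (requires pytest-xdist)
    import pytest

    @pytest.mark.xdist_group(name="hybrid_trainer")
    def test_checkpoint_save():
        # scheduled on the same xdist worker as every other "hybrid_trainer" test
        ...

    @pytest.mark.xdist_group(name="hybrid_trainer")
    def test_checkpoint_resume():
        ...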
tests/trainer/test_lora_unified_checkpoint.py (+25 -22)

@@ -149,7 +149,7 @@ def __test__(cls):

     def setUp(self):
         """
-        1. update runfrist and rerun to run defined different config
+        1. update runfirst and rerun to run defined different config
         2. update need_allclose to True if you want to check the result
         3. update rtol to the relative value you want to check
         """
@@ -169,7 +169,7 @@ def setUp(self):

         self.run_lora_file = "llm/finetune_generation.py"

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)

     def rerun(self, train_args):
@@ -181,7 +181,7 @@ def testTP4PP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4PP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -196,7 +196,7 @@ def testTP2Sharding4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP2Sharding4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -213,7 +213,7 @@ def testTP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -227,7 +227,7 @@ def testTP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -242,7 +242,7 @@ def testTP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -257,7 +257,7 @@ def testTP2PP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP2PP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -272,7 +272,7 @@ def testPP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -287,7 +287,7 @@ def testPP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -302,7 +302,7 @@ def testPP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -317,7 +317,7 @@ def testSharding8S1(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding8S1"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -332,7 +332,7 @@ def testSharding8S2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding8S2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -347,7 +347,7 @@ def testSharding4S1DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding4S1DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -362,7 +362,7 @@ def testSharding4S2DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding4S2DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -377,7 +377,7 @@ def testSharding2S1DP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding2S1DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -392,7 +392,7 @@ def testSharding2S2DP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding2S2DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -407,7 +407,7 @@ def testDP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["DP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -416,27 +416,29 @@ def testDP8(self):
             np.testing.assert_allclose(res[0], res[1], self.rtol)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()
         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)

     def rerun(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()

         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)

@@ -445,14 +447,15 @@ def rerun(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()

         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)

@@ -469,7 +472,7 @@ def setUp(self):
         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n2c4(self.run_lora_file, **train_args)

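The renamed runfirst/rerun hooks follow a two-pass pattern: runfirst launches a training run for a given parallel config, rerun launches a second run (in the compatibility tests with unified_checkpoint toggled off first), and when need_allclose is set the two results are compared with np.testing.assert_allclose(res[0], res[1], self.rtol). A simplified, self-contained sketch of that pattern; the train_once helper and its return value are stand-ins, not code from the repository:

    import numpy as np

    def train_once(train_args, resume):
        # Stand-in for launching a distributed finetuning run and reading back the
        # final loss; the real tests shell out via run_n1c8 / run_n2c4.
        return np.array([10.1234])

    class CheckpointRoundTrip:
        need_allclose = True
        rtol = 1e-7

        def runfirst(self, train_args):
            # First run: train from scratch and write a checkpoint.
            self.first = train_once(train_args, resume=False)

        def rerun(self, train_args):
            # Second run: resume from the checkpoint written by runfirst.
            self.second = train_once(train_args, resume=True)

        def check(self):
            if self.need_allclose:
                np.testing.assert_allclose(self.first, self.second, rtol=self.rtol)

    case = CheckpointRoundTrip()
    case.runfirst({"unified_checkpoint": 1})
    case.rerun({"unified_checkpoint": 1})
    case.check()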