@@ -149,7 +149,7 @@ def __test__(cls):

     def setUp(self):
         """
-        1. update runfrist and rerun to run defined different config
+        1. update runfirst and rerun to run defined different config
         2. update need_allclose to True if you want to check the result
         3. update rtol to the relative value you want to check
         """
@@ -169,7 +169,7 @@ def setUp(self):

         self.run_lora_file = "llm/finetune_generation.py"

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)

     def rerun(self, train_args):
@@ -181,7 +181,7 @@ def testTP4PP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4PP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -196,7 +196,7 @@ def testTP2Sharding4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP2Sharding4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -213,7 +213,7 @@ def testTP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -227,7 +227,7 @@ def testTP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -242,7 +242,7 @@ def testTP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -257,7 +257,7 @@ def testTP2PP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["TP2PP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -272,7 +272,7 @@ def testPP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -287,7 +287,7 @@ def testPP4DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP4DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -302,7 +302,7 @@ def testPP4Sharding2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["PP4Sharding2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -317,7 +317,7 @@ def testSharding8S1(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding8S1"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -332,7 +332,7 @@ def testSharding8S2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding8S2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -347,7 +347,7 @@ def testSharding4S1DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding4S1DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -362,7 +362,7 @@ def testSharding4S2DP2(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding4S2DP2"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -377,7 +377,7 @@ def testSharding2S1DP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding2S1DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -392,7 +392,7 @@ def testSharding2S2DP4(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["Sharding2S2DP4"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -407,7 +407,7 @@ def testDP8(self):
         remove_ckpt(lora_arguments["output_dir"])

         train_args = self.configs["DP8"]
-        self.runfrist(train_args)
+        self.runfirst(train_args)
         self.rerun(train_args)

         if self.need_allclose:
@@ -416,27 +416,29 @@ def testDP8(self):
         np.testing.assert_allclose(res[0], res[1], self.rtol)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN2C4(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()
         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)

     def rerun(self, train_args):
         self.run_n2c4(self.run_lora_file, **train_args)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestUnifiedCheckpointOnN1C8CheckpointCompatible(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()

         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)

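Note: the decorators added in this hunk use pytest.mark.skipif with a literal True condition, which skips the whole class unconditionally (like pytest.mark.skip) while keeping a condition slot that could later be pointed at the CE environment. A sketch of that idea, assuming a hypothetical IS_CE environment flag that this diff does not define:

    import os

    import pytest

    IS_CE = os.getenv("IS_CE", "0") == "1"  # hypothetical gate, not in this diff

    @pytest.mark.skipif(not IS_CE, reason="Skip for None CE")
    class TestOnlyInCE:
        def test_placeholder(self):
            assert True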
@@ -445,14 +447,15 @@ def rerun(self, train_args):
         self.run_n1c8(self.run_lora_file, **train_args)


+@pytest.mark.skipif(True, reason="Skip for None CE")
 class TestPaddleCheckpointOnN1C8Reset(TestUnifiedCheckpointBase):
     def setUp(self):
         super().setUp()

         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n1c8(self.run_lora_file, **train_args)

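Note: the compatibility suites above share one shape — runfirst pins train_args["unified_checkpoint"] = 0 so the first run saves in the legacy format, rerun resumes from that checkpoint, and when need_allclose is set the base test compares the losses of the two runs via the assert_allclose call seen earlier in this diff. A sketch of that final check, with placeholder values standing in for the parsed losses:

    import numpy as np

    rtol = 1e-7
    res = [2.7315, 2.7315]  # placeholders for the losses of runfirst and rerun
    np.testing.assert_allclose(res[0], res[1], rtol)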
@@ -469,7 +472,7 @@ def setUp(self):
         self.need_allclose = True
         self.rtol = 1e-7

-    def runfrist(self, train_args):
+    def runfirst(self, train_args):
         train_args["unified_checkpoint"] = 0
         self.run_n2c4(self.run_lora_file, **train_args)
