@@ -19,17 +19,17 @@
 from typing import List, Union
 
 import paddle
-from paddle.distributed.checkpoint.load_state_dict import (
+from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.load_state_dict import (
     _load_state_dict,
     get_rank_to_read_files,
 )
-from paddle.distributed.checkpoint.metadata import (
+from paddle.distributed.flex_checkpoint.dcp.metadata import (
     LocalTensorIndex,
     LocalTensorMetadata,
     Metadata,
 )
-from paddle.distributed.checkpoint.utils import flatten_state_dict
-from paddle.distributed.fleet.utils.log_util import logger
+from paddle.distributed.flex_checkpoint.dcp.utils import flatten_state_dict
 
 MODEL_WEIGHT_SUFFIX = ".pdparams"
 OPTIMIZER_WEIGHT_SUFFIX = ".pdopt"
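
Note: this hunk moves the checkpoint helpers from paddle.distributed.checkpoint to paddle.distributed.flex_checkpoint.dcp (and regroups the logger import). Callers that must run against both module layouts can guard the import; the fallback pattern below is a suggestion of ours, only the two paths themselves come from this diff:

    try:
        # Newer layout, as introduced by this change.
        from paddle.distributed.flex_checkpoint.dcp.metadata import (
            LocalTensorIndex,
            LocalTensorMetadata,
            Metadata,
        )
    except ImportError:
        # Older layout, pre flex_checkpoint.
        from paddle.distributed.checkpoint.metadata import (
            LocalTensorIndex,
            LocalTensorMetadata,
            Metadata,
        )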
@@ -206,7 +206,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 global_offset = [0] * self.tp_degree
                 for item in shard_info:
                     tp_rank = item[0]["tp_rank"]
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     local_tensor_meta_data = LocalTensorMetadata((global_offset[tp_rank],), item[1], item[2])
                     local_tensor_index = LocalTensorIndex(state_name_with_tp_rank, (global_offset[tp_rank],))
                     global_offset[tp_rank] += item[1][0]
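
Note: only the string formatting changes in this hunk; the surrounding loop keeps a running per-tp-rank offset so each shard of a tensor-parallel state knows where it starts along axis 0. A standalone sketch of that bookkeeping, with shard_info entries assumed to be ({"tp_rank": r}, shape, dtype) tuples to match how item[0], item[1], and item[2] are used above:

    # Hypothetical shard_info; the parameter name and sizes are invented.
    shard_info = [
        ({"tp_rank": 0}, (4,), "float32"),
        ({"tp_rank": 1}, (4,), "float32"),
        ({"tp_rank": 0}, (6,), "float32"),
    ]
    global_offset = [0] * 2  # tp_degree == 2 for illustration
    for meta, shape, dtype in shard_info:
        r = meta["tp_rank"]
        name = "linear_0.w_0" + "_tp" + f"{r:02d}"  # e.g. "linear_0.w_0_tp00"
        print(name, "starts at", global_offset[r], "spans", shape[0])
        global_offset[r] += shape[0]  # advance this rank's offset along axis 0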
@@ -225,7 +225,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 renamed_state_dict = {}
                 (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
                 for state_name, state_value in state_dict.items():
-                    state_name_with_tp_rank = state_name + "_tp" + "{:02d}".format(tp_rank)
+                    state_name_with_tp_rank = state_name + "_tp" + f"{tp_rank:02d}"
                     renamed_state_dict[state_name_with_tp_rank] = state_value
 
                 source_state_dict_for_merge_sharding[file_name] = renamed_state_dict
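
Note: this rename tags every key in a file's state dict with the tp rank that produced it, so shards from different ranks can coexist in one flat dict. The formatting change itself is behavior-preserving:

    tp_rank = 3
    assert "{:02d}".format(tp_rank) == f"{tp_rank:02d}" == "03"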
@@ -235,7 +235,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             sharding_metas_keys = []
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    sharding_metas_keys.append("tp{:02d}_pp{:02d}".format(i, j))
+                    sharding_metas_keys.append(f"tp{i:02d}_pp{j:02d}")
             for key in sharding_metas_keys:
                 param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                 for param_name, param_shape_and_dtype in param_meta.items():
@@ -253,7 +253,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             all_param_meta = {}
             for i in range(self.tp_degree):
                 for j in range(self.pp_degree):
-                    key = "tp{:02d}_pp{:02d}".format(i, j)
+                    key = f"tp{i:02d}_pp{j:02d}"
                     param_meta = self.model_meta["sharding_metas"][key]["param_meta"]
                     for param_name, param_shape_and_dtype in param_meta.items():
                         all_param_meta[param_name] = param_shape_and_dtype
@@ -269,7 +269,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
             with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
                 for key in cur_rank_need_load_model_state_keys:
                     for tp_rank in range(self.tp_degree):
-                        tp_rank_suffix = "_tp{:02d}".format(tp_rank)
+                        tp_rank_suffix = f"_tp{tp_rank:02d}"
                         optimizer_state_dict[key + ".moment1" + tp_rank_suffix] = paddle.zeros(
                             (param_flattened_shapes[key],), "float32"
                         )
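
Note: the context lines pre-allocate flattened moment buffers on CPU so the merged sharded optimizer states have a destination to load into. A minimal sketch of the same pattern; param_flattened_shapes maps a parameter name to its flattened element count, and the concrete name and sizes here are invented:

    import paddle

    param_flattened_shapes = {"linear_0.w_0": 24}  # hypothetical: 24 elements
    optimizer_state_dict = {}
    with paddle.base.dygraph.guard(place=paddle.CPUPlace()):
        for key, numel in param_flattened_shapes.items():
            for tp_rank in range(2):  # tp_degree == 2 for illustration
                suffix = f"_tp{tp_rank:02d}"
                # Zero-filled 1-D buffer, overwritten once the real shards load.
                optimizer_state_dict[key + ".moment1" + suffix] = paddle.zeros((numel,), "float32")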
@@ -353,7 +353,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                 else:
                     concat_optimier_state_dict[opt_state_name_removed_tp_rank] = tp_tensors[0]
 
-            fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+            fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
             local_tensor_meta_data = {}
             local_tensor_index = {}
             for k, v in concat_optimier_state_dict.items():
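
Note: the hunk cuts off before this loop's body. Judging from the constructor calls in the 206 hunk, it presumably records, per merged tensor, its metadata and the fake per-rank file that holds it; the sketch below is our guess, not the PR's actual code:

    for k, v in concat_optimier_state_dict.items():
        offsets = (0,) * len(v.shape)  # assumption: merged tensors are unsharded
        local_tensor_meta_data[k] = LocalTensorMetadata(offsets, tuple(v.shape), v.dtype)
        local_tensor_index[LocalTensorIndex(k, offsets)] = fake_file_name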
@@ -472,7 +472,7 @@ def gen_metadata_and_prepare_source_state_dict(self):
                                 reshaped_v = v.reshape(shape)
                                 target_state_dict[k] = reshaped_v
 
-                fake_file_name = "{:02d}".format(self.cur_rank) + ".distcp"
+                fake_file_name = f"{self.cur_rank:02d}" + ".distcp"
                 local_tensor_meta_data = {}
                 local_tensor_index = {}
                 for k, v in target_state_dict.items():
@@ -911,7 +911,7 @@ def rename_using_model_meta(self, file_name):
                 self.model_meta = json.load(file)
 
         (tp_rank, pp_rank, sharding_rank) = self.get_distribution_rank_from_file_name(file_name)
-        dist_strategy_key = "tp" + "{:02d}".format(tp_rank) + "_" + "pp" + "{:02d}".format(pp_rank)
+        dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"
         # Map model weight names to their corresponding names of master_weights in the optimizer state.
         if file_name.endswith(OPTIMIZER_WEIGHT_SUFFIX):
             structure_name_mapping = self.model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]
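
Note: the key built here indexes the sharding_metas section of the model meta JSON, e.g. tp_rank=0, pp_rank=1 gives "tp00_pp01". A sketch of the lookup against a toy model_meta; the field names come from this diff, the concrete values are invented:

    tp_rank, pp_rank = 0, 1
    dist_strategy_key = "tp" + f"{tp_rank:02d}" + "_" + "pp" + f"{pp_rank:02d}"  # "tp00_pp01"

    model_meta = {
        "sharding_metas": {
            "tp00_pp01": {
                "structure_name_mapping": {"linear_0.w_0": "linear_0.w_0"},  # invented
            }
        }
    }
    mapping = model_meta["sharding_metas"][dist_strategy_key]["structure_name_mapping"]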