diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml
index 2c5db1b8f..52c438587 100644
--- a/configs/ppo_gptj.yml
+++ b/configs/ppo_gptj.yml
@@ -35,9 +35,9 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 0.2 # value term weight
-  scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
+  scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
   ref_mean: null
-  ref_std: null # rescale rewards with this deviation
+  ref_std: null # rescale rewards with this deviation
   cliprange_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py
index a596a10e7..18340b1e2 100644
--- a/examples/ppo_sentiments.py
+++ b/examples/ppo_sentiments.py
@@ -46,7 +46,6 @@ def reward_fn(samples: List[str]) -> List[float]:
 prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
 
 model = trlx.train(
-    "lvwerra/gpt2-imdb",
     reward_fn=reward_fn,
     prompts=prompts,
     eval_prompts=["I don't know much about Hungarian underground"] * 64,
diff --git a/trlx/model/accelerate_base_model.py b/trlx/model/accelerate_base_model.py
index 14e46d9fe..4f0d883a6 100644
--- a/trlx/model/accelerate_base_model.py
+++ b/trlx/model/accelerate_base_model.py
@@ -23,7 +23,7 @@
 import ray
 from ray.air import session
 from ray.air.checkpoint import Checkpoint
 
-from trlx.utils import filter_non_scalars, get_git_tag
+from trlx.utils import filter_non_scalars, get_distributed_config, get_git_tag
 
 @register_model
@@ -76,9 +76,12 @@ def __init__(self, config, train_mode=True):
         run_name = f"{script_name}/{model_name}"
 
         if self.accelerator.is_main_process and not ray.is_initialized():
+            config_dict = self.config.to_dict()
+            dist_config = get_distributed_config(self.accelerator)
+            config_dict["distributed"] = dist_config
             self.accelerator.init_trackers(
                 project_name=self.config.train.project_name,
-                config=self.config.to_dict(),
+                config=config_dict,
                 init_kwargs={
                     "wandb": {
                         "name": run_name,
diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py
index 3545e8e6f..6480a43fd 100644
--- a/trlx/utils/__init__.py
+++ b/trlx/utils/__init__.py
@@ -11,6 +11,9 @@
 from torch.optim.lr_scheduler import ChainedScheduler, LinearLR
 from torchtyping import TensorType
 
+import accelerate
+from accelerate import Accelerator
+
 
 def set_seed(seed: int):
     """
@@ -57,6 +60,32 @@ def safe_mkdir(path: str):
     os.mkdir(path)
 
 
+def get_distributed_config(accelerator: Accelerator):
+    """
+    Return accelerator distributed config
+    """
+
+    accelerate_config = accelerator.state
+    dist_config = {
+        "mixed_precision": accelerate_config.mixed_precision,
+        "num_gpus": accelerate_config.num_processes,
+    }
+
+    if hasattr(accelerator.state, "deepspeed_plugin"):
+        ds_plugin = accelerator.state.deepspeed_plugin
+        dist_config.update(
+            {
+                "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps,
+                "gradient_clipping": ds_plugin.gradient_clipping,
+                "zero_stage": ds_plugin.zero_stage,
+                "offload_optimizer_device": ds_plugin.offload_optimizer_device,
+                "offload_param_device": ds_plugin.offload_param_device,
+            }
+        )
+
+    return dist_config
+
+
 # Stats
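A minimal usage sketch (not part of the patch): how the new get_distributed_config helper feeds the tracker config the same way the accelerate_base_model.py hunk wires it up. It assumes a single-process run without a DeepSpeed plugin, so only the base keys are populated; the project name and config contents below are illustrative placeholders, not values from the patch.

from accelerate import Accelerator

from trlx.utils import get_distributed_config

accelerator = Accelerator(log_with="wandb")

config_dict = {"model": "hypothetical-model"}  # stand-in for self.config.to_dict()
config_dict["distributed"] = get_distributed_config(accelerator)
# e.g. {"mixed_precision": "no", "num_gpus": 1} when no deepspeed_plugin is set

accelerator.init_trackers(project_name="my-project", config=config_dict)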