From 630c95c5d84e6756497dac5a1b25b944a0498c0d Mon Sep 17 00:00:00 2001 From: Dahoas Date: Tue, 22 Nov 2022 23:38:37 +0000 Subject: [PATCH 1/6] updating gptj-config --- configs/ppo_gptj.yml | 4 ++++ examples/ppo_sentiments.py | 3 +-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/configs/ppo_gptj.yml b/configs/ppo_gptj.yml index ebd92b3a8..52c438587 100644 --- a/configs/ppo_gptj.yml +++ b/configs/ppo_gptj.yml @@ -35,6 +35,10 @@ method: cliprange: 0.2 # clip range cliprange_value: 0.2 # clip range vf_coef: 0.2 # value term weight + scale_reward: False # False | "ref" | "running" estimate against which to scale rewards + ref_mean: null + ref_std: null # rescale rewards with this deviation + cliprange_reward: 10 gen_kwargs: max_length: 48 # LM max sample gen length min_length: 48 # LM min sample gen length diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index a596a10e7..4638b5b7e 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -17,7 +17,7 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -default_config = yaml.safe_load(open("configs/ppo_config.yml")) +default_config = yaml.safe_load(open("configs/ppo_gptj.yml")) def main(hparams={}): @@ -46,7 +46,6 @@ def reward_fn(samples: List[str]) -> List[float]: prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] model = trlx.train( - "lvwerra/gpt2-imdb", reward_fn=reward_fn, prompts=prompts, eval_prompts=["I don't know much about Hungarian underground"] * 64, From 39990b65156aeb69ea0f6eecb5b0c5f59a0223f0 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Wed, 23 Nov 2022 05:26:37 +0000 Subject: [PATCH 2/6] added distributed config logging to wandb --- configs/ppo_config.yml | 2 +- examples/ppo_sentiments.py | 2 +- trlx/model/accelerate_base_model.py | 7 +++++-- trlx/utils/__init__.py | 20 ++++++++++++++++++++ trlx/utils/loading.py | 2 +- 5 files changed, 28 insertions(+), 5 deletions(-) diff --git 
a/configs/ppo_config.yml b/configs/ppo_config.yml index e32409490..2156b93e8 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -21,7 +21,7 @@ train: pipeline: "PromptPipeline" # prompt pipeline to load orchestrator: "PPOOrchestrator" # orchestrator to load - entity_name: "jon-tow" + entity_name: "dahoas" # put your wandb login here method: name: 'ppoconfig' # Name of RL method config diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py index 4638b5b7e..18340b1e2 100644 --- a/examples/ppo_sentiments.py +++ b/examples/ppo_sentiments.py @@ -17,7 +17,7 @@ def get_positive_score(scores): return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] -default_config = yaml.safe_load(open("configs/ppo_gptj.yml")) +default_config = yaml.safe_load(open("configs/ppo_config.yml")) def main(hparams={}): diff --git a/trlx/model/accelerate_base_model.py b/trlx/model/accelerate_base_model.py index bdbb31bd4..ebf85b76b 100644 --- a/trlx/model/accelerate_base_model.py +++ b/trlx/model/accelerate_base_model.py @@ -23,7 +23,7 @@ import ray from ray.air import session from ray.air.checkpoint import Checkpoint -from trlx.utils import filter_non_scalars +from trlx.utils import filter_non_scalars, get_distributed_config @register_model @@ -78,9 +78,12 @@ def __init__(self, config, train_mode=True): run_name = f"{script_name}/{model_name}" if self.accelerator.is_main_process and not ray.is_initialized(): + config_dict = self.config.to_dict() + dist_config = get_distributed_config(self.accelerator) + config_dict["distributed"] = dist_config self.accelerator.init_trackers( project_name=self.config.train.project_name, - config=self.config.to_dict(), + config=config_dict, init_kwargs={ "wandb": { "name": run_name, diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py index 335d7887d..cdbd5cefe 100644 --- a/trlx/utils/__init__.py +++ b/trlx/utils/__init__.py @@ -9,6 +9,9 @@ from torch.optim.lr_scheduler import ChainedScheduler, LinearLR from 
torchtyping import TensorType +import accelerate +from accelerate import Accelerator + def flatten(L: Iterable[Iterable[Any]]) -> Iterable[Any]: """ @@ -45,6 +48,23 @@ def safe_mkdir(path: str): os.mkdir(path) +def get_distributed_config(accelerator: Accelerator): + """ + Return accelerator distributed config + """ + ds_plugin = accelerator.state.deepspeed_plugin + accelerate_config = accelerator.state + return { + "mixed_precision": accelerate_config.mixed_precision, + "num_gpus": accelerate_config.num_processes, + "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, + "gradient_clipping": ds_plugin.gradient_clipping, + "zero_stage": ds_plugin.zero_stage, + "offload_optimizer_device": ds_plugin.offload_optimizer_device, + "offload_param_device": ds_plugin.offload_param_device, + } + + # Stats diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index 4b603c80f..47227c904 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -49,4 +49,4 @@ def get_orchestrator(name: str) -> Callable: else: raise Exception( "Error: Trying to access an orchestrator that has not been registered" - ) + ) \ No newline at end of file From 4b2c5320f6a164ae57c1ff8579b75523f20c9054 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Wed, 23 Nov 2022 06:00:03 +0000 Subject: [PATCH 3/6] update --- trlx/utils/loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index 47227c904..4b603c80f 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -49,4 +49,4 @@ def get_orchestrator(name: str) -> Callable: else: raise Exception( "Error: Trying to access an orchestrator that has not been registered" - ) \ No newline at end of file + ) From 25399c25abc137e3db3328befa7395114b6b5b32 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Wed, 23 Nov 2022 06:31:31 +0000 Subject: [PATCH 4/6] black fix --- trlx/model/accelerate_base_model.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/trlx/model/accelerate_base_model.py b/trlx/model/accelerate_base_model.py index 906bdeaaa..4000d4b4f 100644 --- a/trlx/model/accelerate_base_model.py +++ b/trlx/model/accelerate_base_model.py @@ -26,7 +26,6 @@ from trlx.utils import filter_non_scalars, get_distributed_config, get_git_tag - @register_model class AccelerateRLModel(BaseRLModel): """ From fb701e759ae9dac98061b8182fb494530f841855 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Wed, 23 Nov 2022 15:36:04 +0000 Subject: [PATCH 5/6] adding check for ds_plugin --- trlx/utils/__init__.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py index 9982a4387..c5bc7be55 100644 --- a/trlx/utils/__init__.py +++ b/trlx/utils/__init__.py @@ -53,18 +53,27 @@ def get_distributed_config(accelerator: Accelerator): """ Return accelerator distributed config """ - ds_plugin = accelerator.state.deepspeed_plugin + accelerate_config = accelerator.state - return { + dist_config = { "mixed_precision": accelerate_config.mixed_precision, "num_gpus": accelerate_config.num_processes, - "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, - "gradient_clipping": ds_plugin.gradient_clipping, - "zero_stage": ds_plugin.zero_stage, - "offload_optimizer_device": ds_plugin.offload_optimizer_device, - "offload_param_device": ds_plugin.offload_param_device, } + if hasattr(accelerator.state, "deepspeed_plugin"): + ds_plugin = accelerator.state.deepspeed_plugin + dist_config.update( + { + "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, + "gradient_clipping": ds_plugin.gradient_clipping, + "zero_stage": ds_plugin.zero_stage, + "offload_optimizer_device": ds_plugin.offload_optimizer_device, + "offload_param_device": ds_plugin.offload_param_device, + } + ) + + return dist_config + # Stats From 65ee9c01069402a11817d5ca793cc4a8a7b59339 Mon Sep 17 00:00:00 2001 From: Dahoas Date: Wed, 23 Nov 2022 15:37:30 +0000 Subject: [PATCH 
6/6] removing wandb entity name from default config --- configs/ppo_config.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml index 2156b93e8..24425e62b 100644 --- a/configs/ppo_config.yml +++ b/configs/ppo_config.yml @@ -21,7 +21,6 @@ train: pipeline: "PromptPipeline" # prompt pipeline to load orchestrator: "PPOOrchestrator" # orchestrator to load - entity_name: "dahoas" # put your wandb login here method: name: 'ppoconfig' # Name of RL method config