updating gptj-config (#109)
* updating gptj-config

* added distributed config logging to wandb

* update

* black fix

* adding check for ds_plugin

* removing wandb entity name from default config

Co-authored-by: Louis Castricato <[email protected]>
Dahoas and LouisCastricato authored Dec 4, 2022
1 parent b60f05e commit b229288
Showing 4 changed files with 36 additions and 5 deletions.
4 changes: 2 additions & 2 deletions configs/ppo_gptj.yml
@@ -35,9 +35,9 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 0.2 # value term weight
-  scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
+  scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
   ref_mean: null
-  ref_std: null # rescale rewards with this deviation
+  ref_std: null # rescale rewards with this deviation
   cliprange_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
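For readers skimming the config above: scale_reward selects how rewards are normalized before PPO, either not at all (False), against the fixed ref_mean/ref_std statistics ("ref"), or against a running estimate ("running"). A minimal sketch of the idea, using a hypothetical scale_rewards helper that is not trlx's actual code:

    # Illustrative sketch of the scale_reward options; not trlx's exact code.
    def scale_rewards(rewards, scale_reward=False, ref_mean=None, ref_std=None):
        if scale_reward == "ref":
            # rescale against the fixed reference statistics from the config
            mean = ref_mean if ref_mean is not None else 0.0
            return [(r - mean) / ref_std for r in rewards]
        if scale_reward == "running":
            # trlx instead keeps a running mean/std estimate; elided here
            raise NotImplementedError
        return rewards  # scale_reward: False leaves rewards untouched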
1 change: 0 additions & 1 deletion examples/ppo_sentiments.py
@@ -46,7 +46,6 @@ def reward_fn(samples: List[str]) -> List[float]:
prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]

model = trlx.train(
"lvwerra/gpt2-imdb",
reward_fn=reward_fn,
prompts=prompts,
eval_prompts=["I don't know much about Hungarian underground"] * 64,
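The hunk header above references reward_fn, which maps generated samples to scalar rewards. A minimal sketch of such a function, assuming a recent transformers version and the lvwerra/distilbert-imdb sentiment model; the label handling is illustrative and not taken from this diff:

    from typing import List

    from transformers import pipeline

    # Illustrative only: score each sample by its positive-sentiment probability.
    sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", top_k=2)

    def reward_fn(samples: List[str]) -> List[float]:
        outputs = sentiment_fn(samples)
        # With top_k=2 each output is a list of {label, score} dicts.
        return [
            next(d["score"] for d in out if d["label"] == "POSITIVE")
            for out in outputs
        ]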
7 changes: 5 additions & 2 deletions trlx/model/accelerate_base_model.py
@@ -23,7 +23,7 @@
 import ray
 from ray.air import session
 from ray.air.checkpoint import Checkpoint
-from trlx.utils import filter_non_scalars, get_git_tag
+from trlx.utils import filter_non_scalars, get_distributed_config, get_git_tag


 @register_model
@@ -76,9 +76,12 @@ def __init__(self, config, train_mode=True):
         run_name = f"{script_name}/{model_name}"

         if self.accelerator.is_main_process and not ray.is_initialized():
+            config_dict = self.config.to_dict()
+            dist_config = get_distributed_config(self.accelerator)
+            config_dict["distributed"] = dist_config
             self.accelerator.init_trackers(
                 project_name=self.config.train.project_name,
-                config=self.config.to_dict(),
+                config=config_dict,
                 init_kwargs={
                     "wandb": {
                         "name": run_name,
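The net effect is that the W&B run config now carries a nested "distributed" section alongside the flattened trlx config. An illustrative shape of the dict handed to init_trackers, with made-up values (the exact keys come from get_distributed_config in trlx/utils below):

    # Illustrative only; values depend on the accelerate launch configuration.
    config_dict = {
        # ...fields from self.config.to_dict()...
        "distributed": {
            "mixed_precision": "fp16",
            "num_gpus": 8,
            # present only when a DeepSpeed plugin is configured:
            "zero_stage": 2,
            "gradient_accumulation_steps": 1,
            "gradient_clipping": 1.0,
            "offload_optimizer_device": "none",
            "offload_param_device": "none",
        },
    }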
29 changes: 29 additions & 0 deletions trlx/utils/__init__.py
@@ -11,6 +11,9 @@
 from torch.optim.lr_scheduler import ChainedScheduler, LinearLR
 from torchtyping import TensorType

+import accelerate
+from accelerate import Accelerator
+

 def set_seed(seed: int):
     """
@@ -57,6 +60,32 @@ def safe_mkdir(path: str):
     os.mkdir(path)


+def get_distributed_config(accelerator: Accelerator):
+    """
+    Return the distributed config of the accelerator as a dict.
+    """
+
+    accelerate_config = accelerator.state
+    dist_config = {
+        "mixed_precision": accelerate_config.mixed_precision,
+        "num_gpus": accelerate_config.num_processes,
+    }
+
+    if hasattr(accelerator.state, "deepspeed_plugin"):
+        ds_plugin = accelerator.state.deepspeed_plugin
+        dist_config.update(
+            {
+                "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps,
+                "gradient_clipping": ds_plugin.gradient_clipping,
+                "zero_stage": ds_plugin.zero_stage,
+                "offload_optimizer_device": ds_plugin.offload_optimizer_device,
+                "offload_param_device": ds_plugin.offload_param_device,
+            }
+        )
+
+    return dist_config
+
+
 # Stats

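A quick usage sketch for the new helper, assuming a plain single-process Accelerator; the printed values will vary with the launch configuration:

    from accelerate import Accelerator

    from trlx.utils import get_distributed_config

    # Minimal usage sketch: with a plain (non-DeepSpeed) Accelerator only the
    # two base keys are present; the DeepSpeed fields are added when a
    # deepspeed_plugin is configured (e.g. via `accelerate launch`).
    accelerator = Accelerator()
    print(get_distributed_config(accelerator))
    # e.g. {'mixed_precision': 'no', 'num_gpus': 1}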
