updating gptj-config #109
Conversation
examples/ppo_sentiments.py
Outdated
@@ -17,7 +17,7 @@ def get_positive_score(scores):
     return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"]


-default_config = yaml.safe_load(open("configs/ppo_config.yml"))
+default_config = yaml.safe_load(open("configs/ppo_gptj.yml"))
Are you sure you want to change the default to gpt-j? I wouldn't mind a second script which imports main from here and changes the config, as the simplest option right now, until we switch to hydra or something else.
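A minimal sketch of the "second script" idea: reuse the example's default config and override only the model-specific fields, instead of changing the default. The helper and all keys/values below are illustrative, not trlX's actual schema.

```python
# Sketch: merge nested overrides into a default config (illustrative only).
def with_overrides(default_config, overrides):
    """Return a copy of default_config with nested overrides applied."""
    merged = dict(default_config)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = with_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged

# A hypothetical gptj_sentiments.py could then do:
#   from ppo_sentiments import main, default_config
#   main(with_overrides(default_config,
#                       {"model": {"model_path": "EleutherAI/gpt-j-6B"}}))
```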
Agreed, this is a bad idea.
Yeah I just forgot to change this back.
Added distributed config logging to wandb
configs/ppo_config.yml
Outdated
@@ -21,6 +21,7 @@ train:

   pipeline: "PromptPipeline"  # prompt pipeline to load
   orchestrator: "PPOOrchestrator"  # orchestrator to load
+  entity_name: "dahoas"  # put your wandb login here
This will give a somewhat cryptic error if you're not logged in as dahoas; maybe we can make this an optional environment variable instead?
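A sketch of the optional-environment-variable idea: only pass an entity to wandb when one is explicitly set, otherwise let W&B fall back to whatever account is logged in. The variable name `WANDB_ENTITY` and the helper are assumptions for illustration.

```python
import os

def wandb_init_kwargs(project):
    """Build wandb.init kwargs, adding an entity only if one is set
    (sketch; the WANDB_ENTITY variable name is an assumption)."""
    kwargs = {"project": project}
    entity = os.environ.get("WANDB_ENTITY")  # None when unset
    if entity:
        kwargs["entity"] = entity
    return kwargs
```

With no variable set, `wandb.init(**wandb_init_kwargs("trlx"))` would use the logged-in account instead of a hard-coded one.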
Let's fix this in a future PR.
Have it automatically determine what W&B account is logged in. cc @ayulockin
It already does so; this option is optional and was added in #78.
Otherwise it forces people to wandb disable, as in #106.
trlx/utils/__init__.py
Outdated
    return {
        "mixed_precision": accelerate_config.mixed_precision,
        "num_gpus": accelerate_config.num_processes,
        "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps,
I think if DeepSpeed is not used then ds_plugin is None, giving an attribute error here:
https://github.com/huggingface/accelerate/blob/e4e5611e5d4270a846caf42cba3388e54b83f074/src/accelerate/state.py#L62
Maybe some processing of repr(accelerator.state) (which is output by accelerate env) would be equivalent here.
I guess it's true that in some cases a user may not be using DeepSpeed, so a check should be performed. However, with the accelerator.state object there are some items I don't care about (such as the local rank).
Updating the gptj config and verifying gptj runs on 8 A100s with zero2
https://wandb.ai/dahoas/trlx/runs/5vc7xsx8?workspace=user-dahoas