refactor: remove orchestrator abstraction from API #289
Conversation
Thanks jon! I think you might have to also change something for the NeMo trainer?
I think this mostly looks good, except for the NeMo part; in the future we can try to refactor and break up the big `make_experience` functions.
```diff
@@ -89,8 +89,12 @@ def train( # noqa: C901
     eval_prompts = prompts[:batch_size]

     pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
-    orch = get_orchestrator(config.train.orchestrator)(trainer, pipeline, chunk_size=config.method.chunk_size)
-    orch.make_experience(config.method.num_rollouts)
+    trainer.add_prompt_pipeline(pipeline)
```
Re `add_prompt_pipeline` - yeah, that is a bit awkward; it should be possible to pass it via args if you move the `get_trainer` call into the PPO part of the branch? But if it's too messy then np.

You could also probably replace the `get_pipeline` with `PromptPipeline`, since all the models use the same pipeline.
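For instance (a minimal sketch reusing the names from the hunk above, so `prompts`, `max_prompt_length`, and `trainer` come from the surrounding `train()` context; the `PromptPipeline` import path is an assumption):

```python
from trlx.pipeline.offline_pipeline import PromptPipeline  # import path assumed

# Instead of routing through the registry:
#   pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
# construct the one pipeline every model uses directly:
pipeline = PromptPipeline(prompts, max_prompt_length, trainer.tokenizer)
trainer.add_prompt_pipeline(pipeline)
```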
- Yeah, I wasn't sure if it'd be too messy - let me give it a go :)
- I agree on replacing the `get_pipeline` stuff with just `PromptPipeline`; I originally did that but reverted before creating the PR to limit the scope to the issue being addressed. I can't think of any other prompt pipelines, so I'm not sure what the `_DATAPIPELINE` registry is even intended for?
> so not sure what the `_DATAPIPELINE` registry is even intended for?

It's from the time when every pipeline was supposed to be a specific dataset, each registered deliberately.
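For context, the registry pattern being discussed looks roughly like this (a hand-drawn reconstruction for illustration, not the actual trlx code):

```python
_DATAPIPELINE = {}

def register_datapipeline(cls):
    # Each dataset-specific pipeline was meant to register itself by name.
    _DATAPIPELINE[cls.__name__] = cls
    return cls

def get_pipeline(name):
    # Look a pipeline class up by its registered name.
    return _DATAPIPELINE[name]

@register_datapipeline
class PromptPipeline:
    def __init__(self, prompts, max_prompt_length, tokenizer):
        self.prompts = prompts
        self.max_prompt_length = max_prompt_length
        self.tokenizer = tokenizer
```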
```diff
@@ -95,3 +142,73 @@ def save_pretrained(self, directory: Optional[str] = None):
             "`AccelerateILQLTrainer` does not currently support automatic saving "
             "with `transformers.PreTrainedModel.save_pretrained`."
         )
+
+    def make_experience(self, samples, rewards, max_length=2048):
```
Maybe pull this into somewhere it can be shared with the NeMo impl? I guess this could mean it's worth also passing in the rollout store as an arg, like the `PromptPipeline` for PPO.
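To sketch what that sharing might look like (purely illustrative; `RolloutStore` and this free-function signature are invented for the example and assume a Hugging Face-style tokenizer):

```python
from dataclasses import dataclass

@dataclass
class RolloutStore:
    # Hypothetical container pairing tokenized samples with their rewards.
    input_ids: list
    rewards: list

def make_experience(samples, rewards, tokenizer, max_length=2048):
    # Trainer-agnostic: both the Accelerate and NeMo trainers could call
    # this and feed the resulting store into their own dataloaders.
    input_ids = [
        tokenizer(s, truncation=True, max_length=max_length)["input_ids"]
        for s in samples
    ]
    return RolloutStore(input_ids=input_ids, rewards=list(rewards))
```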
I've duplicated the `make_experience` into NeMo for now. There was a subtle difference in logging: NeMo couldn't recognize the global rank, so the `RANK == 0` checks were forcing each rank to write tables to stdout (the fix is to just use their global-rank check util). I think it might be best to push this off to another PR, because we'll need to revisit this abstraction again for PPO. What do you think?
Yeah, I think that makes sense for now. Maybe `torch.distributed.get_rank()` will work for both, but we can revisit.
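Something like this, for reference (a sketch of the guard being discussed; `is_global_rank_zero` is a hypothetical helper, not the NeMo util mentioned above):

```python
import torch.distributed as dist

def is_global_rank_zero() -> bool:
    # True on global rank 0 across all workers; falls back to True
    # when torch.distributed isn't initialized (single-process runs).
    return not dist.is_initialized() or dist.get_rank() == 0

# Only the global rank-0 process writes rollout tables to stdout.
if is_global_rank_zero():
    print("samples/rewards table goes here")
```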
Thanks jon! I think this looks good to me; we can excise and clean up `get_pipeline` in future work and revisit the sharing between Accelerate and NeMo.
This PR removes all of the orchestrator components, for reasons outlined in #278.

Highlights for reviewer(s):

- Adds a new method to `AcceleratePPOTrainer` called `add_prompt_pipeline` that mimics the prompt pipeline loading and device placement of the removed PPO orchestrator. This is sort of awkward because it requires users to manually call the method before running `make_experience` (the same thing you have to do with `add_eval_pipeline`). Open to suggestions; I'd prefer to pass the pipeline to the trainer constructor, but it breaks the config-first approach currently implemented (should discuss for future refactoring?). A rough usage sketch follows this list.
- Removes dead code related to MagiCARP.
- Removes unused `utils.topk_mask`.
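Roughly, the calling pattern this introduces (a sketch mirroring the `train()` hunk quoted in the review above; `config`, `prompts`, `max_prompt_length`, and `trainer` come from the usual trlx setup):

```python
pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)

# The caller must wire the pipeline in before generating rollouts,
# the same way add_eval_pipeline already works:
trainer.add_prompt_pipeline(pipeline)
trainer.make_experience(config.method.num_rollouts)
```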
Reproduction reports: