[train][2/N] Support for Megatron PP + CP for R3 #1327
Closed
19 commits
a802773 previous code (erictang000)
21b44de lint (erictang000)
23fdc45 make opus take a pass at test + plumbing fully thru generator (erictang000)
8daac59 updated test utils and file to support rollout replay indices (devpatelio)
647426f add helper functions for router visibility and megatron testing, succ… (devpatelio)
d4b753f linter (devpatelio)
f1b9c53 worked w opus to get forward pass logprob diff lower with replay + ru… (erictang000)
8a8fa70 add test for forward backward and fix behavior (erictang000)
410995a working for qwen but not moonlight... debugging moonlight (erictang000)
93eee65 x (erictang000)
9c716a1 fixed test for moonlight by enforcing fused attn (devpatelio)
097d2ad Merge branch 'r3' of https://github.com/NovaSky-AI/SkyRL into HEAD (devpatelio)
6de7d5c x (devpatelio)
591af9b x (devpatelio)
acb35ec clean up (erictang000)
5ad9426 rename var and clean up (erictang000)
7367359 testing replay utils with pp (devpatelio)
b88b820 move rank up (devpatelio)
4bbf22b working CP and PP implementation (erictang000)
@@ -135,6 +135,7 @@ def _postprocess_outputs(self, outputs):
        stop_reasons: List[str] = []
        response_ids: List[List[int]] = []
        response_logprobs: Optional[List[List[float]]] = []
        rollout_expert_indices: Optional[List[List[List[List[int]]]]] = []

        for output in outputs:
            # TODO(tgriggs): Support n>1 sampling.

@@ -156,14 +157,26 @@ def _postprocess_outputs(self, outputs):
                    del token_logprobs
            response_logprobs.append(_logprobs)

            _routed_experts = None
            if resp.routed_experts is not None:
                if hasattr(resp.routed_experts, "tolist"):
                    _routed_experts = resp.routed_experts.tolist()
                else:
                    _routed_experts = resp.routed_experts
            rollout_expert_indices.append(_routed_experts)

        if len(response_logprobs) and response_logprobs[0] is None:
            response_logprobs = None  # hack: assume uniform sampling params

        if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
            rollout_expert_indices = None  # hack: assume uniform sampling params

        return InferenceEngineOutput(
            responses=responses,
            stop_reasons=stop_reasons,
            response_ids=response_ids,
            response_logprobs=response_logprobs,
            rollout_expert_indices=rollout_expert_indices,
        )

    def _get_engine(self):

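The per-response normalization in the hunk above can be tried in isolation. The following is a minimal sketch, not code from the PR: `normalize_routed_experts` is a hypothetical helper name, and the standard-library `array.array` is only a stand-in for whatever array-like type the engine actually returns, which this page does not show.

```python
from array import array
from typing import Any, Optional


def normalize_routed_experts(routed_experts: Optional[Any]) -> Optional[Any]:
    """Hypothetical helper mirroring the duck-typed conversion in the hunk.

    Array-like objects (anything exposing .tolist()) become plain Python
    lists; None and ordinary lists are passed through unchanged.
    """
    if routed_experts is None:
        return None
    if hasattr(routed_experts, "tolist"):
        return routed_experts.tolist()
    return routed_experts


if __name__ == "__main__":
    # array.array is an assumption used only to exercise the .tolist() branch.
    print(normalize_routed_experts(array("i", [0, 3])))
    print(normalize_routed_experts([[0, 3], [1, 2]]))
    print(normalize_routed_experts(None))
```

Because the conversion relies on `hasattr` rather than a concrete type check, any object that offers `tolist()` takes the first branch, and a response without routing data contributes `None` to the per-batch list.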
@@ -321,6 +334,14 @@ def _create_engine(self, *args, **kwargs):
        enable_log_requests = kwargs.pop("enable_log_requests", False)
        max_log_len = kwargs.pop("max_log_len", None)

        # Log if enable_return_routed_experts is being passed
        if "enable_return_routed_experts" in kwargs:
            logger.info(
                f"DEBUG: enable_return_routed_experts={kwargs['enable_return_routed_experts']} is being passed to AsyncEngineArgs"
            )
        else:
            logger.warning("DEBUG: enable_return_routed_experts is NOT in kwargs")

Comment on lines +337 to +343
Contributor
These
Suggested change

        if version.parse(vllm.__version__) >= version.parse("0.10.0"):
            engine_args = vllm.AsyncEngineArgs(enable_log_requests=enable_log_requests, **kwargs)
        else:

🔴 Wrong comparison operator (`== 0` instead of `> 0`) prevents `rollout_expert_indices` from being set to `None`

In `_postprocess_outputs`, line 171 uses `len(rollout_expert_indices) == 0` but should use `len(rollout_expert_indices) > 0` (or truthiness, like the logprobs check on line 168). With `== 0`: (1) if the list is empty, `rollout_expert_indices[0]` raises an `IndexError`; (2) if the list is non-empty (the normal case), the condition is always `False`, so a list of all `None` values (e.g. `[None, None, ...]`) is never collapsed to `None`. This means that when `enable_return_routed_experts` is disabled (the default), downstream code in `inference_engine_client.py:169-171` sees a truthy list of `None`s, sets `add_rollout_expert_indices = True`, and propagates `[None, None, ...]` instead of `None` through the pipeline.

Comparison with correct pattern on line 168

Line 168 (correct):
    if len(response_logprobs) and response_logprobs[0] is None:

Line 171 (broken):
    if len(rollout_expert_indices) == 0 and rollout_expert_indices[0] is None:
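
The two failure modes described in this comment can be reproduced outside the engine. This is a minimal sketch under stated assumptions: `collapse_broken` and `collapse_fixed` are hypothetical names introduced only for illustration, each wrapping one of the two conditions quoted above; neither function exists in the PR.

```python
from typing import Any, List, Optional


def collapse_broken(indices: List[Optional[Any]]) -> Optional[List[Optional[Any]]]:
    """Hypothetical wrapper around the check as written on line 171."""
    # For a non-empty list the left operand is False, so the body never runs.
    # For an empty list the left operand is True and indices[0] raises IndexError.
    if len(indices) == 0 and indices[0] is None:
        return None
    return indices


def collapse_fixed(indices: List[Optional[Any]]) -> Optional[List[Optional[Any]]]:
    """Hypothetical wrapper using the same pattern as the logprobs check on line 168."""
    # Short-circuiting guards the index access: indices[0] is only read
    # when the list has at least one element.
    if len(indices) and indices[0] is None:
        return None
    return indices


if __name__ == "__main__":
    all_none = [None, None, None]
    print(collapse_broken(all_none))  # list of Nones is passed through unchanged
    print(collapse_fixed(all_none))   # collapsed to None
    try:
        collapse_broken([])
    except IndexError as exc:
        print("empty list:", type(exc).__name__)
```

Like the logprobs check it copies, the fixed variant only inspects the first element, so it still relies on the "uniform sampling params" assumption noted in the code comment: either every response carries routing data or none does.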