-
Notifications
You must be signed in to change notification settings - Fork 624
Integrate DeepEP to experimental torchtitan #2107
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
elfiegg
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
elfiegg:loss_bug
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
|
|
||
| from .moe_deepep import MoEWithDeepEP, get_deepep_buffer, get_hidden_bytes | ||
| from .expert_parallel import DeepEPExpertParallel | ||
|
|
||
| __all__ = [ | ||
| "MoEWithDeepEP", | ||
| "get_deepep_buffer", | ||
| "get_hidden_bytes", | ||
| "DeepEPExpertParallel", | ||
| ] | ||
|
|
||
| __version__ = "1.0.0" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from torchtitan.components.loss import build_cross_entropy_loss | ||
| from torchtitan.components.lr_scheduler import build_lr_schedulers | ||
| from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing | ||
| from torchtitan.components.tokenizer import build_hf_tokenizer | ||
| from torchtitan.distributed.pipeline_parallel import pipeline_llm | ||
| from torchtitan.hf_datasets.text_datasets import build_text_dataloader | ||
| from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3StateDictAdapter | ||
| from torchtitan.protocols.train_spec import TrainSpec | ||
|
|
||
| from .model import DeepEPDeepSeekV3Model | ||
| from .parallelize import parallelize_deepseekv3 | ||
|
|
||
|
|
||
| def get_train_spec() -> TrainSpec: | ||
| """ | ||
| Get the training specification for DeepSeek-V3 with DeepEP. | ||
|
|
||
| Returns: | ||
| TrainSpec: Complete training specification including model, parallelization, | ||
| optimization, and data loading functions. | ||
| """ | ||
| return TrainSpec( | ||
| model_cls=DeepEPDeepSeekV3Model, | ||
| model_args=deepseekv3_args, | ||
| parallelize_fn=parallelize_deepseekv3, | ||
| pipelining_fn=pipeline_llm, | ||
| build_optimizers_fn=build_optimizers_with_moe_load_balancing, | ||
| build_lr_schedulers_fn=build_lr_schedulers, | ||
| build_dataloader_fn=build_text_dataloader, | ||
| build_tokenizer_fn=build_hf_tokenizer, | ||
| build_loss_fn=build_cross_entropy_loss, | ||
| state_dict_adapter=DeepSeekV3StateDictAdapter, | ||
| ) | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "get_train_spec", | ||
| "DeepEPDeepSeekV3Model", | ||
| "parallelize_deepseekv3", | ||
| ] | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| """ | ||
| DeepSeek-V3 model wrapper for DeepEP experiments. | ||
|
|
||
| This module provides a DeepSeekV3 model class that is compatible with | ||
| DeepEP's MoE parallelization strategy. | ||
| """ | ||
|
|
||
| from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs | ||
|
|
||
|
|
||
| class DeepEPDeepSeekV3Model(DeepSeekV3Model): | ||
| """ | ||
| DeepSeek-V3 model with DeepEP-compatible initialization. | ||
|
|
||
| This class extends the base DeepSeekV3Model to ensure proper | ||
| initialization for DeepEP experiments. The main difference is | ||
| that MoE layers will be replaced with DeepEP versions during | ||
| the parallelization step. | ||
| """ | ||
|
|
||
| def __init__(self, model_args: DeepSeekV3ModelArgs): | ||
| super().__init__(model_args) | ||
| self.init_weights() | ||
|
|
||
| def init_weights(self, *args, **kwargs): | ||
| """Initialize model weights.""" | ||
| super().init_weights(*args, **kwargs) | ||
|
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of making it an experiment (which restricts it to a special version of deepseek_v3), I think we should integrate it directly in core.
We can have a factory method (e.g.
build_moe) which takes a string (e.g. "deep_ep") to dispatch to this version of MoE.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure that's a great idea! - once I confirm this works for larger models and improves perf
Regarding integrating directly to main - do we need to manage DeepEP dependency at all or we leave it to the users to install?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I prefer
instead of bundling it by default. We can explicitly mention this in try-catch when we do the import.