-
Notifications
You must be signed in to change notification settings - Fork 239
kimi k2 recipe intro #2097
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
kimi k2 recipe intro #2097
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
a61f868
kimi k2 recipe intro
malay-nagda 91877de
no dispatcher backend
malay-nagda 5195fbe
Merge branch 'main' into malay/kimi_k2_init
malay-nagda 785ae08
Merge branch 'main' into malay/kimi_k2_init
malay-nagda 2bce880
add back layout logic
malay-nagda a0793a6
Merge branch 'main' into malay/kimi_k2_init
malay-nagda 43ecbe4
add refactor related changes to perf scripts
malay-nagda d2aea7e
layout when vp=None
malay-nagda 7a6f3ba
copyright disclaimer
malay-nagda fceb8a8
correct layout call
malay-nagda 7c0f4c5
layout for h100
malay-nagda 33be8d1
Merge branch 'main' into malay/kimi_k2_init
malay-nagda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| try: | ||
| import megatron.bridge # noqa: F401 | ||
|
|
||
| HAVE_MEGATRON_BRIDGE = True | ||
| except ModuleNotFoundError: | ||
| HAVE_MEGATRON_BRIDGE = False | ||
|
|
||
| if HAVE_MEGATRON_BRIDGE: | ||
| from .kimi_llm_pretrain import ( | ||
| kimi_k2_pretrain_config_b200, | ||
| kimi_k2_pretrain_config_gb200, | ||
| kimi_k2_pretrain_config_gb300, | ||
| kimi_k2_pretrain_config_h100, | ||
| ) | ||
|
|
||
| from .kimi_workload_base_configs import ( | ||
| KIMI_K2_PRETRAIN_CONFIG_B200_BF16, | ||
| KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS, | ||
| KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB200_BF16, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_BF16, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX, | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4, | ||
| KIMI_K2_PRETRAIN_CONFIG_H100_BF16, | ||
| KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS, | ||
| KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC, | ||
| ) | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "KIMI_K2_PRETRAIN_CONFIG_B200_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB200_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4", | ||
| "KIMI_K2_PRETRAIN_CONFIG_H100_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC", | ||
| ] | ||
|
|
||
| if HAVE_MEGATRON_BRIDGE: | ||
| __all__.extend( | ||
| [ | ||
| "kimi_k2_pretrain_config_gb300", | ||
| "kimi_k2_pretrain_config_gb200", | ||
| "kimi_k2_pretrain_config_b200", | ||
| "kimi_k2_pretrain_config_h100", | ||
| ] | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
|
|
||
| from utils.overrides import set_workload_base_configs | ||
| from utils.precision import get_precision_config | ||
| from utils.utils import get_workload_base_config | ||
|
|
||
| from megatron.bridge.recipes.kimi.kimi_k2 import _get_kimi_k2_pipeline_layout | ||
| from megatron.bridge.recipes.kimi.kimi_k2 import kimi_k2_pretrain_config as pretrain_config | ||
| from megatron.bridge.training.config import ConfigContainer | ||
|
|
||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def set_kimi_k2_common_configs(cfg: ConfigContainer) -> None: | ||
| """Set common performance configurations for all Kimi-K2 configs.""" | ||
| cfg.model.seq_length = 4096 | ||
| cfg.dataset.sequence_length = 4096 | ||
|
|
||
| cfg.model.moe_router_fusion = True | ||
| cfg.model.recompute_granularity = "selective" | ||
| cfg.dist.enable_megatron_core_experimental = True | ||
|
|
||
| cfg.mixed_precision.grad_reduce_in_fp32 = False | ||
| cfg.ddp.grad_reduce_in_fp32 = False | ||
|
|
||
| cfg.model.moe_router_force_load_balancing = True | ||
| cfg.model.qk_clip = True | ||
|
|
||
|
|
||
| def kimi_k2_pretrain_config_gb300( | ||
| precision: str = "bf16", mock: bool = True, config_variant: str = "v1" | ||
| ) -> ConfigContainer: | ||
| """GB300, baseline config.""" | ||
| base_cfg = get_workload_base_config( | ||
| model_family_name="kimi", | ||
| model_recipe_name="kimi_k2", | ||
| gpu="gb300", | ||
| compute_dtype=precision.upper(), | ||
| task="pretrain", | ||
| config_variant=config_variant, | ||
| ) | ||
|
|
||
| cfg = pretrain_config() | ||
| precision_config = get_precision_config(precision) | ||
| cfg.mixed_precision = precision_config | ||
|
|
||
| if base_cfg.moe_flex_dispatcher_backend is not None: | ||
| cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend | ||
|
|
||
| if base_cfg.pp_layout: | ||
| cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout | ||
| else: | ||
| # Recompute layout based on updated PP/VP sizes | ||
| pp_size = base_cfg.pipeline_model_parallel_size | ||
| vp_size = base_cfg.virtual_pipeline_model_parallel_size | ||
| layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size) | ||
| cfg.model.pipeline_model_parallel_layout = layout | ||
|
|
||
| set_kimi_k2_common_configs(cfg) | ||
| set_workload_base_configs(cfg, base_cfg) | ||
|
|
||
| cfg.comm_overlap.overlap_grad_reduce = True | ||
|
|
||
| # Setting num_workers and pin_memory to 0 and False respectively gives better performance. | ||
| # we are debugging this and might change this in the future. | ||
| cfg.dataset.num_workers = 0 | ||
| cfg.dataset.pin_memory = False | ||
|
|
||
| return cfg | ||
|
|
||
|
|
||
| def kimi_k2_pretrain_config_gb200( | ||
| precision: str = "bf16", mock: bool = True, config_variant: str = "v1" | ||
| ) -> ConfigContainer: | ||
| """GB200, baseline config.""" | ||
| base_cfg = get_workload_base_config( | ||
| model_family_name="kimi", | ||
| model_recipe_name="kimi_k2", | ||
| gpu="gb200", | ||
| compute_dtype=precision.upper(), | ||
| task="pretrain", | ||
| config_variant=config_variant, | ||
| ) | ||
|
|
||
| cfg = pretrain_config() | ||
| precision_config = get_precision_config(precision) | ||
| cfg.mixed_precision = precision_config | ||
|
|
||
| if base_cfg.moe_flex_dispatcher_backend is not None: | ||
| cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend | ||
|
|
||
| if base_cfg.pp_layout: | ||
| cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout | ||
| else: | ||
| # Recompute layout based on updated PP/VP sizes | ||
| pp_size = base_cfg.pipeline_model_parallel_size | ||
| vp_size = base_cfg.virtual_pipeline_model_parallel_size | ||
| layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size) | ||
| cfg.model.pipeline_model_parallel_layout = layout | ||
|
|
||
| set_kimi_k2_common_configs(cfg) | ||
| set_workload_base_configs(cfg, base_cfg) | ||
|
|
||
| cfg.comm_overlap.overlap_grad_reduce = True | ||
|
|
||
| # Setting num_workers and pin_memory to 0 and False respectively gives better performance. | ||
| # we are debugging this and might change this in the future. | ||
| cfg.dataset.num_workers = 0 | ||
| cfg.dataset.pin_memory = False | ||
|
|
||
| return cfg | ||
|
|
||
|
|
||
| def kimi_k2_pretrain_config_b200( | ||
| precision: str = "bf16", mock: bool = True, config_variant: str = "v1" | ||
| ) -> ConfigContainer: | ||
| """B200, baseline config.""" | ||
| base_cfg = get_workload_base_config( | ||
| model_family_name="kimi", | ||
| model_recipe_name="kimi_k2", | ||
| gpu="b200", | ||
| compute_dtype=precision.upper(), | ||
| task="pretrain", | ||
| config_variant=config_variant, | ||
| ) | ||
|
|
||
| cfg = pretrain_config() | ||
| precision_config = get_precision_config(precision) | ||
| cfg.mixed_precision = precision_config | ||
|
|
||
| if base_cfg.moe_flex_dispatcher_backend is not None: | ||
| cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend | ||
|
|
||
| if base_cfg.pp_layout: | ||
| cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout | ||
| else: | ||
| # Recompute layout based on updated PP/VP sizes | ||
| pp_size = base_cfg.pipeline_model_parallel_size | ||
| vp_size = base_cfg.virtual_pipeline_model_parallel_size | ||
| layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size) | ||
| cfg.model.pipeline_model_parallel_layout = layout | ||
|
|
||
| set_kimi_k2_common_configs(cfg) | ||
| set_workload_base_configs(cfg, base_cfg) | ||
|
|
||
| cfg.comm_overlap.overlap_grad_reduce = True | ||
|
|
||
| return cfg | ||
|
|
||
|
|
||
| def kimi_k2_pretrain_config_h100( | ||
| precision: str = "bf16", mock: bool = True, config_variant: str = "v1" | ||
| ) -> ConfigContainer: | ||
| """H100, baseline config.""" | ||
| base_cfg = get_workload_base_config( | ||
| model_family_name="kimi", | ||
| model_recipe_name="kimi_k2", | ||
| gpu="h100", | ||
| compute_dtype=precision.upper(), | ||
| task="pretrain", | ||
| config_variant=config_variant, | ||
| ) | ||
|
|
||
| cfg = pretrain_config() | ||
| precision_config = get_precision_config(precision) | ||
| cfg.mixed_precision = precision_config | ||
|
|
||
| if base_cfg.moe_flex_dispatcher_backend is not None: | ||
| cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend | ||
|
|
||
| if base_cfg.pp_layout: | ||
| cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout | ||
| else: | ||
| # Recompute layout based on updated PP/VP sizes | ||
| pp_size = base_cfg.pipeline_model_parallel_size | ||
| vp_size = base_cfg.virtual_pipeline_model_parallel_size | ||
| layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size) | ||
| cfg.model.pipeline_model_parallel_layout = layout | ||
|
|
||
| set_kimi_k2_common_configs(cfg) | ||
| set_workload_base_configs(cfg, base_cfg) | ||
|
|
||
| # Disabling to avoid functional errors. TODO: Test with it enabled and keep it enabled if it works. | ||
| cfg.comm_overlap.overlap_grad_reduce = False | ||
|
|
||
| return cfg |
109 changes: 109 additions & 0 deletions
109
scripts/performance/configs/kimi/kimi_workload_base_configs.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Workload base presets for Kimi-K2 performance configs.""" | ||
|
|
||
| from dataclasses import replace | ||
|
|
||
| from utils.utils import WorkloadBaseConfig | ||
|
|
||
|
|
||
| BASE_KIMI_K2_CONFIG = WorkloadBaseConfig( | ||
| expert_tensor_parallel_size=1, | ||
| ) | ||
|
|
||
|
|
||
| KIMI_K2_PRETRAIN_CONFIG_GB300 = replace( | ||
| BASE_KIMI_K2_CONFIG, | ||
| num_gpus=256, | ||
| global_batch_size=2048, | ||
| pipeline_model_parallel_size=4, | ||
| virtual_pipeline_model_parallel_size=4, | ||
| expert_model_parallel_size=64, | ||
| moe_flex_dispatcher_backend="hybridep", | ||
| moe_a2a_overlap=False, | ||
| cuda_graph_impl="transformer_engine", | ||
| cuda_graph_scope=["attn", "moe_router", "moe_preprocess"], | ||
| recompute_modules=["moe_act"], | ||
| ) | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_BF16 = KIMI_K2_PRETRAIN_CONFIG_GB300 | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_GB300 | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX = KIMI_K2_PRETRAIN_CONFIG_GB300 | ||
| KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4 = KIMI_K2_PRETRAIN_CONFIG_GB300 | ||
|
|
||
|
|
||
| KIMI_K2_PRETRAIN_CONFIG_GB200 = replace( | ||
| BASE_KIMI_K2_CONFIG, | ||
| num_gpus=256, | ||
| global_batch_size=2048, | ||
| pipeline_model_parallel_size=4, | ||
| virtual_pipeline_model_parallel_size=4, | ||
| expert_model_parallel_size=64, | ||
| moe_flex_dispatcher_backend="hybridep", | ||
| moe_a2a_overlap=False, | ||
| recompute_modules=["mla_up_proj"], | ||
| cuda_graph_impl="transformer_engine", | ||
| cuda_graph_scope=["moe_router", "moe_preprocess"], | ||
| ) | ||
| KIMI_K2_PRETRAIN_CONFIG_GB200_BF16 = KIMI_K2_PRETRAIN_CONFIG_GB200 | ||
| KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_GB200 | ||
| KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX = KIMI_K2_PRETRAIN_CONFIG_GB200 | ||
|
|
||
|
|
||
| KIMI_K2_PRETRAIN_CONFIG_B200 = replace( | ||
| BASE_KIMI_K2_CONFIG, | ||
| num_gpus=256, | ||
| pipeline_model_parallel_size=16, | ||
| expert_model_parallel_size=16, | ||
| global_batch_size=2048, | ||
| recompute_modules=["mla_up_proj"], | ||
| moe_a2a_overlap=False, | ||
| ) | ||
| KIMI_K2_PRETRAIN_CONFIG_B200_BF16 = KIMI_K2_PRETRAIN_CONFIG_B200 | ||
| KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_B200 | ||
| KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX = KIMI_K2_PRETRAIN_CONFIG_B200 | ||
|
|
||
|
|
||
| KIMI_K2_PRETRAIN_CONFIG_H100 = replace( | ||
| BASE_KIMI_K2_CONFIG, | ||
| num_gpus=1024, | ||
| tensor_model_parallel_size=8, | ||
| pipeline_model_parallel_size=16, | ||
| virtual_pipeline_model_parallel_size=2, | ||
| expert_model_parallel_size=64, | ||
| global_batch_size=8192, | ||
| recompute_modules=["mla_up_proj", "mlp"], | ||
| moe_a2a_overlap=False, | ||
| pp_layout="Et|(tt|)*30L", | ||
| ) | ||
| KIMI_K2_PRETRAIN_CONFIG_H100_BF16 = KIMI_K2_PRETRAIN_CONFIG_H100 | ||
| KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_H100 | ||
| KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC = KIMI_K2_PRETRAIN_CONFIG_H100 | ||
malay-nagda marked this conversation as resolved.
Show resolved
Hide resolved
malay-nagda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| __all__ = [ | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB200_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX", | ||
| "KIMI_K2_PRETRAIN_CONFIG_B200_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX", | ||
| "KIMI_K2_PRETRAIN_CONFIG_H100_BF16", | ||
| "KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS", | ||
| "KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC", | ||
| ] | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.