71 changes: 71 additions & 0 deletions scripts/performance/configs/kimi/__init__.py
@@ -0,0 +1,71 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

try:
import megatron.bridge # noqa: F401

HAVE_MEGATRON_BRIDGE = True
except ModuleNotFoundError:
HAVE_MEGATRON_BRIDGE = False

if HAVE_MEGATRON_BRIDGE:
from .kimi_llm_pretrain import (
kimi_k2_pretrain_config_b200,
kimi_k2_pretrain_config_gb200,
kimi_k2_pretrain_config_gb300,
kimi_k2_pretrain_config_h100,
)

from .kimi_workload_base_configs import (
KIMI_K2_PRETRAIN_CONFIG_B200_BF16,
KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS,
KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX,
KIMI_K2_PRETRAIN_CONFIG_GB200_BF16,
KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS,
KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX,
KIMI_K2_PRETRAIN_CONFIG_GB300_BF16,
KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS,
KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX,
KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4,
KIMI_K2_PRETRAIN_CONFIG_H100_BF16,
KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS,
KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC,
)


__all__ = [
"KIMI_K2_PRETRAIN_CONFIG_B200_BF16",
"KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX",
"KIMI_K2_PRETRAIN_CONFIG_GB200_BF16",
"KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX",
"KIMI_K2_PRETRAIN_CONFIG_GB300_BF16",
"KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX",
"KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4",
"KIMI_K2_PRETRAIN_CONFIG_H100_BF16",
"KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC",
]

if HAVE_MEGATRON_BRIDGE:
__all__.extend(
[
"kimi_k2_pretrain_config_gb300",
"kimi_k2_pretrain_config_gb200",
"kimi_k2_pretrain_config_b200",
"kimi_k2_pretrain_config_h100",
]
)
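
Because the recipe builders are only re-exported when megatron.bridge imports successfully, a consumer can mirror the same HAVE_MEGATRON_BRIDGE guard. The snippet below is only a sketch; the configs.kimi import path assumes scripts/performance is on sys.path, which is not stated in this diff.

from configs.kimi import HAVE_MEGATRON_BRIDGE
from configs.kimi.kimi_workload_base_configs import KIMI_K2_PRETRAIN_CONFIG_GB200_BF16

if HAVE_MEGATRON_BRIDGE:
    from configs.kimi import kimi_k2_pretrain_config_gb200

    # Builds the full ConfigContainer; requires megatron.bridge at runtime.
    cfg = kimi_k2_pretrain_config_gb200(precision="bf16")
else:
    # Without megatron.bridge, only the static base presets are importable.
    base = KIMI_K2_PRETRAIN_CONFIG_GB200_BF16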
201 changes: 201 additions & 0 deletions scripts/performance/configs/kimi/kimi_llm_pretrain.py
@@ -0,0 +1,201 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from utils.overrides import set_workload_base_configs
from utils.precision import get_precision_config
from utils.utils import get_workload_base_config

from megatron.bridge.recipes.kimi.kimi_k2 import _get_kimi_k2_pipeline_layout
from megatron.bridge.recipes.kimi.kimi_k2 import kimi_k2_pretrain_config as pretrain_config
from megatron.bridge.training.config import ConfigContainer


logger = logging.getLogger(__name__)


def set_kimi_k2_common_configs(cfg: ConfigContainer) -> None:
"""Set common performance configurations for all Kimi-K2 configs."""
cfg.model.seq_length = 4096
cfg.dataset.sequence_length = 4096

cfg.model.moe_router_fusion = True
cfg.model.recompute_granularity = "selective"
cfg.dist.enable_megatron_core_experimental = True

cfg.mixed_precision.grad_reduce_in_fp32 = False
cfg.ddp.grad_reduce_in_fp32 = False

cfg.model.moe_router_force_load_balancing = True
cfg.model.qk_clip = True


def kimi_k2_pretrain_config_gb300(
precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
) -> ConfigContainer:
"""GB300, baseline config."""
base_cfg = get_workload_base_config(
model_family_name="kimi",
model_recipe_name="kimi_k2",
gpu="gb300",
compute_dtype=precision.upper(),
task="pretrain",
config_variant=config_variant,
)

cfg = pretrain_config()
precision_config = get_precision_config(precision)
cfg.mixed_precision = precision_config

if base_cfg.moe_flex_dispatcher_backend is not None:
cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend

if base_cfg.pp_layout:
cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout
else:
# Recompute layout based on updated PP/VP sizes
pp_size = base_cfg.pipeline_model_parallel_size
vp_size = base_cfg.virtual_pipeline_model_parallel_size
layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size)
cfg.model.pipeline_model_parallel_layout = layout

set_kimi_k2_common_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

cfg.comm_overlap.overlap_grad_reduce = True

# Setting num_workers=0 and pin_memory=False currently gives better performance.
# We are still debugging why and may change this in the future.
cfg.dataset.num_workers = 0
cfg.dataset.pin_memory = False

return cfg


def kimi_k2_pretrain_config_gb200(
precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
) -> ConfigContainer:
"""GB200, baseline config."""
base_cfg = get_workload_base_config(
model_family_name="kimi",
model_recipe_name="kimi_k2",
gpu="gb200",
compute_dtype=precision.upper(),
task="pretrain",
config_variant=config_variant,
)

cfg = pretrain_config()
precision_config = get_precision_config(precision)
cfg.mixed_precision = precision_config

if base_cfg.moe_flex_dispatcher_backend is not None:
cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend

if base_cfg.pp_layout:
cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout
else:
# Recompute layout based on updated PP/VP sizes
pp_size = base_cfg.pipeline_model_parallel_size
vp_size = base_cfg.virtual_pipeline_model_parallel_size
layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size)
cfg.model.pipeline_model_parallel_layout = layout

set_kimi_k2_common_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

cfg.comm_overlap.overlap_grad_reduce = True

# Setting num_workers=0 and pin_memory=False currently gives better performance.
# We are still debugging why and may change this in the future.
cfg.dataset.num_workers = 0
cfg.dataset.pin_memory = False

return cfg


def kimi_k2_pretrain_config_b200(
precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
) -> ConfigContainer:
"""B200, baseline config."""
base_cfg = get_workload_base_config(
model_family_name="kimi",
model_recipe_name="kimi_k2",
gpu="b200",
compute_dtype=precision.upper(),
task="pretrain",
config_variant=config_variant,
)

cfg = pretrain_config()
precision_config = get_precision_config(precision)
cfg.mixed_precision = precision_config

if base_cfg.moe_flex_dispatcher_backend is not None:
cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend

if base_cfg.pp_layout:
cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout
else:
# Recompute layout based on updated PP/VP sizes
pp_size = base_cfg.pipeline_model_parallel_size
vp_size = base_cfg.virtual_pipeline_model_parallel_size
layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size)
cfg.model.pipeline_model_parallel_layout = layout

set_kimi_k2_common_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

cfg.comm_overlap.overlap_grad_reduce = True

return cfg


def kimi_k2_pretrain_config_h100(
precision: str = "bf16", mock: bool = True, config_variant: str = "v1"
) -> ConfigContainer:
"""H100, baseline config."""
base_cfg = get_workload_base_config(
model_family_name="kimi",
model_recipe_name="kimi_k2",
gpu="h100",
compute_dtype=precision.upper(),
task="pretrain",
config_variant=config_variant,
)

cfg = pretrain_config()
precision_config = get_precision_config(precision)
cfg.mixed_precision = precision_config

if base_cfg.moe_flex_dispatcher_backend is not None:
cfg.model.moe_flex_dispatcher_backend = base_cfg.moe_flex_dispatcher_backend

if base_cfg.pp_layout:
cfg.model.pipeline_model_parallel_layout = base_cfg.pp_layout
else:
# Recompute layout based on updated PP/VP sizes
pp_size = base_cfg.pipeline_model_parallel_size
vp_size = base_cfg.virtual_pipeline_model_parallel_size
layout = _get_kimi_k2_pipeline_layout(pp_size, vp_size)
cfg.model.pipeline_model_parallel_layout = layout

set_kimi_k2_common_configs(cfg)
set_workload_base_configs(cfg, base_cfg)

# Disabled to avoid functional errors on H100. TODO: re-test with overlap_grad_reduce enabled and keep it on if it works.
cfg.comm_overlap.overlap_grad_reduce = False

return cfg
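
For context, calling one of these builders requires nothing beyond the overrides already exposed in the signatures above. A minimal sketch follows; the import path is an assumption about how this script directory is placed on sys.path, and running it still requires megatron.bridge plus the utils helpers.

from configs.kimi.kimi_llm_pretrain import kimi_k2_pretrain_config_h100

cfg = kimi_k2_pretrain_config_h100(precision="bf16", config_variant="v1")
print(cfg.model.seq_length)                  # 4096, set in set_kimi_k2_common_configs
print(cfg.comm_overlap.overlap_grad_reduce)  # False on H100 (see the TODO above)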
109 changes: 109 additions & 0 deletions scripts/performance/configs/kimi/kimi_workload_base_configs.py
@@ -0,0 +1,109 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Workload base presets for Kimi-K2 performance configs."""

from dataclasses import replace

from utils.utils import WorkloadBaseConfig


BASE_KIMI_K2_CONFIG = WorkloadBaseConfig(
expert_tensor_parallel_size=1,
)


KIMI_K2_PRETRAIN_CONFIG_GB300 = replace(
BASE_KIMI_K2_CONFIG,
num_gpus=256,
global_batch_size=2048,
pipeline_model_parallel_size=4,
virtual_pipeline_model_parallel_size=4,
expert_model_parallel_size=64,
moe_flex_dispatcher_backend="hybridep",
moe_a2a_overlap=False,
cuda_graph_impl="transformer_engine",
cuda_graph_scope=["attn", "moe_router", "moe_preprocess"],
recompute_modules=["moe_act"],
)
KIMI_K2_PRETRAIN_CONFIG_GB300_BF16 = KIMI_K2_PRETRAIN_CONFIG_GB300
KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_GB300
KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX = KIMI_K2_PRETRAIN_CONFIG_GB300
KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4 = KIMI_K2_PRETRAIN_CONFIG_GB300


KIMI_K2_PRETRAIN_CONFIG_GB200 = replace(
BASE_KIMI_K2_CONFIG,
num_gpus=256,
global_batch_size=2048,
pipeline_model_parallel_size=4,
virtual_pipeline_model_parallel_size=4,
expert_model_parallel_size=64,
moe_flex_dispatcher_backend="hybridep",
moe_a2a_overlap=False,
recompute_modules=["mla_up_proj"],
cuda_graph_impl="transformer_engine",
cuda_graph_scope=["moe_router", "moe_preprocess"],
)
KIMI_K2_PRETRAIN_CONFIG_GB200_BF16 = KIMI_K2_PRETRAIN_CONFIG_GB200
KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_GB200
KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX = KIMI_K2_PRETRAIN_CONFIG_GB200


KIMI_K2_PRETRAIN_CONFIG_B200 = replace(
BASE_KIMI_K2_CONFIG,
num_gpus=256,
pipeline_model_parallel_size=16,
expert_model_parallel_size=16,
global_batch_size=2048,
recompute_modules=["mla_up_proj"],
moe_a2a_overlap=False,
)
KIMI_K2_PRETRAIN_CONFIG_B200_BF16 = KIMI_K2_PRETRAIN_CONFIG_B200
KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_B200
KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX = KIMI_K2_PRETRAIN_CONFIG_B200


KIMI_K2_PRETRAIN_CONFIG_H100 = replace(
BASE_KIMI_K2_CONFIG,
num_gpus=1024,
tensor_model_parallel_size=8,
pipeline_model_parallel_size=16,
virtual_pipeline_model_parallel_size=2,
expert_model_parallel_size=64,
global_batch_size=8192,
recompute_modules=["mla_up_proj", "mlp"],
moe_a2a_overlap=False,
pp_layout="Et|(tt|)*30L",
)
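# The layout string above presumably follows Megatron-Core's pipeline-layout notation
# ("E" = embedding, "t" = one transformer layer, "L" = loss, "|" = stage boundary,
# "(tt|)*30" = thirty two-layer stages), which would expand to 32 virtual stages,
# matching PP=16 x VP=2 above. This reading is an assumption; verify against
# _get_kimi_k2_pipeline_layout in the Kimi-K2 recipe.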
KIMI_K2_PRETRAIN_CONFIG_H100_BF16 = KIMI_K2_PRETRAIN_CONFIG_H100
KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS = KIMI_K2_PRETRAIN_CONFIG_H100
KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC = KIMI_K2_PRETRAIN_CONFIG_H100


__all__ = [
"KIMI_K2_PRETRAIN_CONFIG_GB300_BF16",
"KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_GB300_FP8_MX",
"KIMI_K2_PRETRAIN_CONFIG_GB300_NVFP4",
"KIMI_K2_PRETRAIN_CONFIG_GB200_BF16",
"KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_GB200_FP8_MX",
"KIMI_K2_PRETRAIN_CONFIG_B200_BF16",
"KIMI_K2_PRETRAIN_CONFIG_B200_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_B200_FP8_MX",
"KIMI_K2_PRETRAIN_CONFIG_H100_BF16",
"KIMI_K2_PRETRAIN_CONFIG_H100_FP8_CS",
"KIMI_K2_PRETRAIN_CONFIG_H100_FP8_SC",
]
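
Every preset above is derived from BASE_KIMI_K2_CONFIG via dataclasses.replace, so a local variant follows the same pattern. The sketch below uses a hypothetical name and an arbitrary batch size purely for illustration; it is not a tuned configuration.

# Hypothetical local variant, derived the same way as the presets above.
KIMI_K2_PRETRAIN_CONFIG_GB200_SMALL_GBS = replace(
    KIMI_K2_PRETRAIN_CONFIG_GB200,
    global_batch_size=512,  # illustrative value only, not a tuned setting
)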