356 changes: 332 additions & 24 deletions src/megatron/bridge/models/conversion/model_bridge.py

Large diffs are not rendered by default.

214 changes: 52 additions & 162 deletions src/megatron/bridge/models/deepseek/deepseek_provider.py
@@ -13,110 +13,57 @@
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from typing import Callable, List, Optional, Union

import torch
import torch.nn.functional as F
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.transformer_config import MLATransformerConfig
from megatron.bridge.models.mla_provider import MLAModelProvider
from megatron.bridge.utils.common_utils import get_rank_safe


try:
import transformer_engine # type: ignore # noqa: F401

HAVE_TE = True
except (ImportError, ModuleNotFoundError):
HAVE_TE = False

if TYPE_CHECKING:
from megatron.core.transformer import ModuleSpec

if HAVE_TE:
from megatron.core.utils import is_te_min_version
def _warn_deprecated(old_cls: str, new_cls: str = "MLAModelProvider") -> None:
if get_rank_safe() == 0:
warnings.warn(
f"{old_cls} is deprecated and will be removed in a future release. "
f"Use {new_cls} with MEGATRON_DEFAULTS in the bridge instead.",
DeprecationWarning,
stacklevel=3,
)
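A minimal sketch of how downstream code can escalate or silence these warnings during migration; the message regex is an assumption based only on the text visible in _warn_deprecated above:

import warnings

# Turn the alias warnings into hard errors while migrating call sites...
warnings.filterwarnings("error", category=DeprecationWarning,
                        message=r".*Use MLAModelProvider.*")

# ...or suppress them once the rename is tracked elsewhere.
warnings.filterwarnings("ignore", category=DeprecationWarning,
                        message=r".*Use MLAModelProvider.*")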


@dataclass
class DeepSeekModelProvider(MLATransformerConfig, GPTModelProvider):
"""
Base config for DeepSeek V2 and V3 models.
"""
class DeepSeekModelProvider(MLAModelProvider):
"""Deprecated alias for ``MLAModelProvider``.

transformer_layer_spec: Union["ModuleSpec", Callable[["GPTModelProvider"], "ModuleSpec"]] = partial(
get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE
)
Deprecated:
This alias remains for backward compatibility and will be removed in a
future release. Use ``MLAModelProvider`` instead.
"""

# Model
# Common DeepSeek defaults
normalization: str = "RMSNorm"
activation_func: Callable = F.silu
gated_linear_unit: bool = True # swiglu
gated_linear_unit: bool = True
position_embedding_type: str = "rope"
add_bias_linear: bool = False
share_embeddings_and_output_weights: bool = False
num_attention_heads: int = 128
kv_channels: int = 128
max_position_embeddings: int = 4096
seq_length: int = 4096
rotary_base: float = 10000.0
make_vocab_size_divisible_by: int = 3200
mtp_num_layers: Optional[int] = None
mtp_loss_scaling_factor: Optional[float] = None

# Regularization
attention_dropout: float = 0.0
hidden_dropout: float = 0.0
qk_layernorm: bool = True

# MoE
bf16: bool = True
params_dtype: torch.dtype = torch.bfloat16
moe_grouped_gemm: bool = True
moe_router_pre_softmax: bool = True
moe_token_dispatcher_type: str = "alltoall"
moe_router_load_balancing_type: str = "seq_aux_loss"
moe_shared_expert_overlap: bool = True
moe_router_dtype: Optional[str] = "fp32"

# MLA
q_lora_rank: int = 1536
# MLA defaults
q_lora_rank: Optional[int] = 1536
kv_lora_rank: int = 512
qk_head_dim: int = 128
qk_pos_emb_head_dim: int = 64
v_head_dim: int = 128
rotary_scaling_factor: float = 40
mscale: float = 1.0
mscale_all_dim: float = 1.0

# Miscellaneous
init_method_std: float = 0.006
layernorm_epsilon: float = 1e-6
bf16: bool = True
params_dtype: torch.dtype = torch.bfloat16
async_tensor_model_parallel_allreduce: bool = True
attention_softmax_in_fp32: bool = False
persist_layer_norm: bool = True
num_layers_in_first_pipeline_stage: Optional[int] = None
num_layers_in_last_pipeline_stage: Optional[int] = None
account_for_embedding_in_pipeline_split: bool = False
account_for_loss_in_pipeline_split: bool = False

# MLA specific
multi_latent_attention: bool = True

# fusions
apply_rope_fusion: bool = False
bias_activation_fusion: bool = True
bias_dropout_fusion: bool = True
masked_softmax_fusion: bool = True
gradient_accumulation_fusion: bool = True
cross_entropy_loss_fusion: bool = True
cross_entropy_fusion_impl: str = "te"
moe_permute_fusion: bool = is_te_min_version("2.1.0") if HAVE_TE else False
def __post_init__(self) -> None:
_warn_deprecated("DeepSeekModelProvider")
super().__post_init__()


@dataclass
class DeepSeekV2ModelProvider(DeepSeekModelProvider):
class DeepSeekV2ModelProvider(MLAModelProvider):
"""
DeepSeek-V2 Model: https://github.com/deepseek-ai/DeepSeek-V2
"""
@@ -137,9 +84,13 @@ class DeepSeekV2ModelProvider(DeepSeekModelProvider):
mscale_all_dim: float = 0.707
vocab_size: int = 102400

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekV2ModelProvider")
super().__post_init__()


@dataclass
class DeepSeekV2LiteModelProvider(DeepSeekV2ModelProvider):
class DeepSeekV2LiteModelProvider(MLAModelProvider):
"""
DeepSeek-V2-Lite Model: https://github.com/deepseek-ai/DeepSeek-V2
HuggingFace: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite
@@ -150,7 +101,7 @@ class DeepSeekV2LiteModelProvider(DeepSeekV2ModelProvider):
ffn_hidden_size: int = 10944
num_attention_heads: int = 16
kv_channels: int = 16
q_lora_rank: int = None
q_lora_rank: Optional[int] = None
num_moe_experts: int = 64
moe_ffn_hidden_size: int = 1408
moe_shared_expert_intermediate_size: int = 2816 # 1408 * 2 shared experts
@@ -159,18 +110,25 @@ class DeepSeekV2LiteModelProvider(DeepSeekV2ModelProvider):
moe_router_num_groups: int = 1
moe_router_group_topk: int = 1
moe_router_topk_scaling_factor: float = 1.0
mscale: float = 0.707
mscale_all_dim: float = 0.707
vocab_size: int = 102400

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekV2LiteModelProvider")
super().__post_init__()


@dataclass
class DeepSeekV3ModelProvider(DeepSeekModelProvider):
class DeepSeekV3ModelProvider(MLAModelProvider):
"""
DeepSeek-V3 Model: https://github.com/deepseek-ai/DeepSeek-V3
"""

num_layers: int = 61
hidden_size: int = 7168
ffn_hidden_size: int = 18432
kv_channels: int = 128
num_moe_experts: int = 256
moe_ffn_hidden_size: int = 2048
moe_shared_expert_intermediate_size: int = 2048 # 2048 * 1 shared expert
@@ -190,9 +148,13 @@ class DeepSeekV3ModelProvider(DeepSeekModelProvider):
mscale_all_dim: float = 1.0
vocab_size: int = 129280

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekV3ModelProvider")
super().__post_init__()


@dataclass
class MoonlightModelProvider16B(DeepSeekModelProvider):
class MoonlightModelProvider16B(MLAModelProvider):
"""
Moonlight-16B-A3B Model: https://github.com/moonshotai/Moonlight-16B-A3B

@@ -228,86 +190,14 @@ class MoonlightModelProvider16B(DeepSeekModelProvider):
rotary_percent: float = 1.0
vocab_size: int = 163842


# -----------------------------------------------------------------------------
# Deprecated aliases (to be removed in a future release)
# -----------------------------------------------------------------------------


def _warn_deprecated(old_cls: str, new_cls: str) -> None:
if get_rank_safe() == 0:
warnings.warn(
f"{old_cls} is deprecated and will be removed in a future release. Use {new_cls} instead.",
DeprecationWarning,
stacklevel=2,
)


@dataclass
class DeepSeekProvider(DeepSeekModelProvider):
"""Deprecated alias for ``DeepSeekModelProvider``.

Deprecated:
This alias remains for backward compatibility and will be removed in a
future release. Import and use ``DeepSeekModelProvider`` instead.
"""

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekProvider", "DeepSeekModelProvider")
super().__post_init__()


@dataclass
class DeepSeekV2Provider(DeepSeekV2ModelProvider):
"""Deprecated alias for ``DeepSeekV2ModelProvider``.

Deprecated:
This alias remains for backward compatibility and will be removed in a
future release. Import and use ``DeepSeekV2ModelProvider`` instead.
"""

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekV2Provider", "DeepSeekV2ModelProvider")
super().__post_init__()


@dataclass
class DeepSeekV2LiteProvider(DeepSeekV2LiteModelProvider):
"""Deprecated alias for ``DeepSeekV2LiteModelProvider``.

Deprecated:
This alias remains for backward compatibility and will be removed in a
future release. Import and use ``DeepSeekV2LiteModelProvider`` instead.
"""

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekV2LiteProvider", "DeepSeekV2LiteModelProvider")
_warn_deprecated("MoonlightModelProvider16B")
super().__post_init__()


@dataclass
class DeepSeekV3Provider(DeepSeekV3ModelProvider):
"""Deprecated alias for ``DeepSeekV3ModelProvider``.

Deprecated:
This alias remains for backward compatibility and will be removed in a
future release. Import and use ``DeepSeekV3ModelProvider`` instead.
"""

def __post_init__(self) -> None:
_warn_deprecated("DeepSeekV3Provider", "DeepSeekV3ModelProvider")
super().__post_init__()


@dataclass
class MoonlightProvider(MoonlightModelProvider16B):
"""Deprecated alias for ``MoonlightModelProvider16B``.

Deprecated:
This alias remains for backward compatibility and will be removed in a
future release. Import and use ``MoonlightModelProvider16B`` instead.
"""

def __post_init__(self) -> None:
_warn_deprecated("MoonlightProvider", "MoonlightModelProvider16B")
super().__post_init__()
# Legacy aliases for backward compatibility
DeepSeekProvider = DeepSeekModelProvider
DeepSeekV2Provider = DeepSeekV2ModelProvider
DeepSeekV2LiteProvider = DeepSeekV2LiteModelProvider
DeepSeekV3Provider = DeepSeekV3ModelProvider
MoonlightProvider = MoonlightModelProvider16B
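The module-level assignments above keep the old names importable without per-alias subclasses. A migration sketch, assuming nothing beyond the import path changes for callers:

# Before: resolves to DeepSeekV2ModelProvider, which warns on rank 0
# when instantiated (via the deprecated dataclasses above).
from megatron.bridge.models.deepseek.deepseek_provider import DeepSeekV2Provider

# After: import the consolidated provider directly; the DeepSeek-specific
# defaults are now applied by the bridge (see deepseek_v2_bridge.py below).
from megatron.bridge.models.mla_provider import MLAModelProvider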
78 changes: 57 additions & 21 deletions src/megatron/bridge/models/deepseek/deepseek_v2_bridge.py
@@ -12,41 +12,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from functools import partial

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
from megatron.core.models.gpt.gpt_model import GPTModel

from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list
from megatron.bridge.models.deepseek.deepseek_provider import DeepSeekV2ModelProvider
from megatron.bridge.models.deepseek.common import get_common_mapping_list
from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
from megatron.bridge.models.mla_provider import MLAModelProvider


@MegatronModelBridge.register_bridge(source="DeepseekV2ForCausalLM", target=GPTModel)
class DeepSeekV2Bridge(MegatronModelBridge):
"""
Megatron Bridge for DeepSeek-V2.
try:
import transformer_engine # noqa: F401

As a user you would not use this bridge directly, but through `AutoBridge`.
HAVE_TE = True
except (ImportError, ModuleNotFoundError):
HAVE_TE = False

Example:
>>> from megatron.bridge import AutoBridge
>>> bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True)
>>> provider = bridge.to_megatron_provider()
"""

def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV2ModelProvider:
@MegatronModelBridge.register_bridge(
source="DeepseekV2ForCausalLM",
target=GPTModel,
provider=MLAModelProvider,
model_type="deepseek_v2",
)
class DeepSeekV2Bridge(MegatronModelBridge):
"""Megatron Bridge for DeepSeek-V2."""

def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider:
provider = super().provider_bridge(hf_pretrained)
hf_config = hf_pretrained.config
configs = get_common_configs(hf_pretrained)

configs["fp16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16
configs["bf16"] = self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16
configs["params_dtype"] = self.dtype_from_hf(hf_config, default=torch.float32)
provider.transformer_layer_spec = partial(get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE)
provider.normalization = "RMSNorm"
provider.gated_linear_unit = True
provider.position_embedding_type = "rope"
provider.add_bias_linear = False
provider.share_embeddings_and_output_weights = False
provider.qk_layernorm = True
provider.multi_latent_attention = True

provider.moe_grouped_gemm = True
provider.moe_router_pre_softmax = True
provider.moe_token_dispatcher_type = "alltoall"
provider.moe_router_load_balancing_type = "seq_aux_loss"
provider.moe_shared_expert_overlap = True
provider.moe_router_dtype = "fp32"
provider.moe_permute_fusion = True

provider.apply_rope_fusion = False
provider.bias_activation_fusion = True
provider.bias_dropout_fusion = True
provider.cross_entropy_fusion_impl = "te"
provider.cross_entropy_loss_fusion = True
provider.masked_softmax_fusion = True
provider.persist_layer_norm = True
provider.async_tensor_model_parallel_allreduce = True
provider.gradient_accumulation_fusion = True

provider.hidden_dropout = 0.0
provider.attention_softmax_in_fp32 = False

provider.make_vocab_size_divisible_by = 3200
provider.seq_length = 4096

configs["make_vocab_size_divisible_by"] = 3200
configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha
provider.moe_layer_freq = [0] * hf_config.first_k_dense_replace + [1] * (
hf_config.num_hidden_layers - hf_config.first_k_dense_replace
)
provider.moe_shared_expert_intermediate_size = hf_config.moe_intermediate_size * hf_config.n_shared_experts

provider = DeepSeekV2ModelProvider(**configs)
return provider

def mapping_registry(self) -> MegatronMappingRegistry:
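The entry point described in the docstring removed above is unchanged; a usage sketch mirroring that original example:

from megatron.bridge import AutoBridge

# AutoBridge dispatches on the HF architecture string ("DeepseekV2ForCausalLM",
# per register_bridge above) to DeepSeekV2Bridge, whose provider_bridge layers
# the DeepSeek defaults shown in this diff onto the generic MLAModelProvider.
bridge = AutoBridge.from_hf_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True)
provider = bridge.to_megatron_provider()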