diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9db6f8036a73..eb9c71741951 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -365,6 +365,7 @@ th {
 | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
 | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | ✅︎ |
 | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
+| `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2`. | ✅︎ | ✅︎ | ✅︎ |
 | `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | ✅︎ | ✅︎ |
 | `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | ✅︎ |
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index b1a61ade5364..2ac9761c9c00 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -786,6 +786,7 @@ def __init__(
         enable_eplb: bool = False,
         num_redundant_experts: int = 0,
         has_bias: bool = False,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         if params_dtype is None:
@@ -866,6 +867,7 @@ def __init__(
         self.e_score_correction_bias = e_score_correction_bias
         self.apply_router_weight_on_input = apply_router_weight_on_input
         self.activation = activation
+        self.use_presharded_weights = use_presharded_weights
 
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1086,10 +1088,11 @@ def _load_w13(self,
                   tp_rank: int,
                   load_full: bool = False):
+        should_skip_sharding = self.use_presharded_weights or load_full
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        if not load_full:
+        if not should_skip_sharding:
             loaded_weight = loaded_weight.narrow(shard_dim,
                                                  shard_size * tp_rank,
                                                  shard_size)
@@ -1110,11 +1113,12 @@ def _load_w2(self,
                  tp_rank: int,
                  load_full: bool = False):
+        should_skip_sharding = self.use_presharded_weights or load_full
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        if not load_full:
+        if not should_skip_sharding:
             loaded_weight = loaded_weight.narrow(shard_dim,
                                                  shard_size * tp_rank,
                                                  shard_size)
diff --git a/vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
new file mode 100644
index 000000000000..3c1b11f166a0
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/grok1_scaling_rope.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+
+import torch
+
+from .base import RotaryEmbedding
+from .common import (yarn_find_correction_range, yarn_get_mscale,
+                     yarn_linear_ramp_mask)
+
+
+class Grok1ScalingRotaryEmbedding(RotaryEmbedding):
+    """Scale the RotaryEmbedding in a way similar to the YaRN method.
+    https://arxiv.org/pdf/2309.00071."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extra_method: str = "yarn_log",
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extra_method = extra_method
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
+            self.rotary_dim)
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (1 - yarn_linear_ramp_mask(
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
+        if self.extra_method in ["original"]:
+            inv_freq = inv_freq_extrapolation
+        elif self.extra_method in ["yarn", "yarn_linear"]:
+            inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) +
+                        inv_freq_extrapolation * inv_freq_mask)
+        elif self.extra_method == "yarn_log":
+            inv_freq = torch.exp(
+                torch.log(inv_freq_extrapolation) * inv_freq_mask +
+                torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask))
+        elif self.extra_method == "theta_scale":
+            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+            theta_scale_exponent = self.base**(
+                math.log(self.max_position_embeddings * self.scaling_factor /
+                         (2 * math.pi)) /
+                math.log(self.max_position_embeddings / (2 * math.pi)))
+            inv_freq = torch.tensor(
+                1.0 / (theta_scale_exponent**(exponents / self.rotary_dim)),
+                dtype=torch.float32,
+            )
+        else:
+            raise ValueError(
+                f"Unknown extrapolation method: {self.extra_method}")
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(self.max_position_embeddings * self.scaling_factor,
+                         dtype=torch.float32)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        # cos = freqs.cos() * self.mscale
+        # sin = freqs.sin() * self.mscale
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index a59113438337..106ac4e904f7 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -33,15 +33,21 @@
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (QKVParallelLinear,
+from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+                                               QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.rotary_embedding.grok1_scaling_rope import (
+    Grok1ScalingRotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -60,6 +66,41 @@
 DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169
 
 
+class Grok1MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        reduce_results=True,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=reduce_results,
+            prefix=f"{prefix}.down_proj",
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each
     expert across all ranks.
@@ -77,6 +118,8 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
+                 reduce_results: bool = True,
+                 use_presharded_weights: bool = False,
                  prefix: str = ""):
         super().__init__()
         self.hidden_size = hidden_size
@@ -94,11 +137,12 @@ def __init__(self,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
+            reduce_results=reduce_results,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
             activation="gelu",
+            use_presharded_weights=use_presharded_weights,
             prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -112,19 +156,46 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return final_hidden_states.view(orig_shape)
 
 
+def get_rope_scaling(config):
+    rope_type = getattr(config, "rope_type", None)
+    if rope_type:
+        original_max_position_embeddings = getattr(
+            config, "original_max_position_embeddings", None)
+        scaling_factor = getattr(config, "scaling_factor", None)
+        extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
+        attn_factor = getattr(config, "attn_factor", 1.0)
+        beta_fast = getattr(config, "beta_fast", 32)
+        beta_slow = getattr(config, "beta_slow", 1)
+        rope_scaling = {
+            "extra_method": rope_type,
+            "max_position_embeddings": original_max_position_embeddings,
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow,
+            "dtype": torch.float,
+        }
+        return rope_scaling
+    else:
+        return None
+
+
 class Grok1Attention(nn.Module):
 
     def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        num_kv_heads: int,
-        max_position: int = 4096 * 32,
-        rope_theta: float = 10000,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        config=None,  # Added config parameter
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        config=None,  # Added config parameter
+        reduce_results: bool = True,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -148,6 +219,8 @@ def __init__(
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        rope_scaling = get_rope_scaling(config)
+        self.alt_stream = alt_stream or torch.cuda.Stream()
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -162,19 +235,34 @@ def __init__(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
+            reduce_results=reduce_results,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
         )
-        self.rotary_emb = get_rope(
-            self.head_dim,
-            rotary_dim=self.head_dim,
-            max_position=max_position,
-            base=int(self.rope_theta),
-            is_neox_style=True,
-        )
 
         attn_logits_soft_cap = max(
             getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+        self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims",
+                                             False)
+
+        if rope_scaling is not None:
+            self.rotary_emb = Grok1ScalingRotaryEmbedding(
+                self.head_dim,
+                rotary_dim=(self.head_dim if not self.rope_rotate_half_dims
+                            else self.head_dim // 2),
+                base=int(self.rope_theta),
+                is_neox_style=True,
+                **rope_scaling,
+            )
+        else:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=(self.head_dim if not self.rope_rotate_half_dims
+                            else self.head_dim // 2),
+                max_position=max_position,
+                base=int(self.rope_theta),
+                is_neox_style=True,
+            )
 
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -184,8 +272,9 @@ def __init__(
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
                               prefix=f"{prefix}.attn")
-        self.attn_multiplier = getattr(self.config, "attn_output_multiplier",
-                                       1.0) if self.config else 1.0
+        self.attn_multiplier = getattr(
+            self.config, "attn_output_multiplier",
+            DEFAULT_ATTN_OUTPUT_MULTIPLIER) if self.config else 1.0
 
     def forward(
         self,
@@ -208,10 +297,14 @@ def __init__(
         config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        load_presharded_moe: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        is_grok1: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
+        self.num_experts = config.num_experts if is_grok1 else config.num_local_experts
         # Check for fp8 quantization
         self.use_fp8 = False
         if quant_config is not None:
@@ -219,31 +312,81 @@ def __init__(
                 lambda: False)()
             if not self.use_fp8 and hasattr(quant_config, "is_fp8"):
                 self.use_fp8 = quant_config.is_fp8
+        self.residual_moe = getattr(config, "residual_moe", False)
+        self.alt_stream = alt_stream or torch.cuda.Stream()
+        self.is_grok1 = is_grok1
 
         # Requires transformers > 4.32.0
         # Default rope_theta value if not in config
-        rope_theta = 10000
-        self.attn = Grok1Attention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            max_position=config.max_position_embeddings,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            config=config)  # Pass config to Grok1Attention
-
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(config, "num_experts", 8)
-        num_experts_per_tok = getattr(config, "num_experts_per_tok", 2)
-
-        self.moe_block = Grok1MoE(num_experts=num_experts,
-                                  top_k=num_experts_per_tok,
-                                  hidden_size=config.hidden_size,
-                                  intermediate_size=config.intermediate_size,
-                                  quant_config=quant_config,
-                                  prefix=f"{prefix}.moe_block")
+        rope_theta = getattr(config, "rope_theta", 10000)
+        if is_grok1:
+            self.attn = Grok1Attention(
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                max_position=(config.context_len if hasattr(
+                    config, "context_len") else
+                              config.max_position_embeddings),
+                num_kv_heads=config.num_key_value_heads,
+                rope_theta=rope_theta,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                config=config,  # Pass config to Grok1Attention
+                alt_stream=self.alt_stream,
+            )
+        else:
+            self.self_attn = Grok1Attention(
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                max_position=(config.context_len if hasattr(
+                    config, "context_len") else
+                              config.max_position_embeddings),
+                num_kv_heads=config.num_key_value_heads,
+                rope_theta=rope_theta,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+                config=config,  # Pass config to Grok1Attention
+                alt_stream=self.alt_stream,
+            )
+
+        if self.num_experts > 0:
+            if is_grok1:
+                self.moe_block = Grok1MoE(
+                    num_experts=self.num_experts,
+                    top_k=config.num_experts_per_tok,
+                    hidden_size=config.hidden_size,
+                    intermediate_size=getattr(
+                        config,
+                        "moe_intermediate_size",
+                        getattr(config, "intermediate_size", None),
+                    ),
+                    quant_config=quant_config,
+                    reduce_results=not self.residual_moe,
+                    use_presharded_weights=load_presharded_moe,
+                    prefix=f"{prefix}.moe_block")
+            else:
+                self.block_sparse_moe = Grok1MoE(
+                    num_experts=self.num_experts,
+                    top_k=config.num_experts_per_tok,
+                    hidden_size=config.hidden_size,
+                    intermediate_size=getattr(
+                        config,
+                        "moe_intermediate_size",
+                        getattr(config, "intermediate_size", None),
+                    ),
+                    quant_config=quant_config,
+                    reduce_results=not self.residual_moe,
+                    use_presharded_weights=load_presharded_moe,
+                    prefix=f"{prefix}.block_sparse_moe")
+            if self.residual_moe:
+                self.mlp = Grok1MLP(hidden_size=config.hidden_size,
+                                    intermediate_size=config.intermediate_size,
+                                    quant_config=quant_config,
+                                    reduce_results=False,
+                                    prefix=f"{prefix}.mlp")
+        else:
+            raise NotImplementedError("Number of experts must be > 0.")
 
         self.pre_attn_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
@@ -254,6 +397,20 @@ def __init__(
         self.post_moe_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
 
+        if self.num_experts > 0:
+            if self.residual_moe:
+                # NOTE: self.block_sparse_moe modifies its input in place,
+                # so it must run after the MLP; watch for ordering errors.
+                if get_tensor_model_parallel_world_size() > 1:
+                    self.ffn = lambda x: tensor_model_parallel_all_reduce(
+                        self.moe_with_rmoe(x))
+                else:
+                    self.ffn = self.moe_with_rmoe
+            else:
+                self.ffn = self.moe_block if is_grok1 else self.block_sparse_moe
+        else:
+            raise NotImplementedError("Number of experts must be > 0.")
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -268,26 +425,48 @@ def forward(
             hidden_states, residual = self.pre_attn_norm(
                 hidden_states, residual)
 
-        hidden_states = self.attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
+        if self.is_grok1:
+            hidden_states = self.attn(
+                positions=positions,
+                hidden_states=hidden_states,
+            )
+        else:
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+            )
 
         # Post attention normalization
         hidden_states = self.post_attn_norm(hidden_states)
 
-        # MoE block with normalization
+        # Fully Connected
        hidden_states, residual = self.pre_moe_norm(hidden_states, residual)
-        hidden_states = self.moe_block(hidden_states)
+        hidden_states = self.ffn(hidden_states)
         hidden_states = self.post_moe_norm(hidden_states)
 
         return hidden_states, residual
 
+    def moe_with_rmoe(self, x):
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        mlp_result = self.mlp(x)
+        with torch.cuda.stream(self.alt_stream):
+            # MoE must not be in-place: it would race with the MLP stream.
+            moe_result = self.moe_block(
+                x) if self.is_grok1 else self.block_sparse_moe(x)
+        current_stream.wait_stream(self.alt_stream)
+        return (mlp_result + moe_result) / 1.4142135623730951  # sqrt(2)
+
 
 @support_torch_compile
 class Grok1Model(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 load_presharded_moe: bool = False,
+                 is_grok1: bool = False,
+                 prefix: str = ""):
         super().__init__()
 
         config = vllm_config.model_config.hf_config
@@ -313,17 +492,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             quant_config=quant_config,
         )
 
+        self.alt_stream = torch.cuda.Stream()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Grok1DecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
-            ),
+            lambda prefix: Grok1DecoderLayer(config,
+                                             cache_config,
+                                             quant_config=quant_config,
+                                             load_presharded_moe=
+                                             load_presharded_moe,
+                                             alt_stream=self.alt_stream,
+                                             is_grok1=is_grok1,
+                                             prefix=prefix),
             prefix=f"{prefix}.layers")
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
+        self.is_grok1 = is_grok1
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
@@ -362,12 +548,11 @@ def forward(
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Map Grok1's unique expert parameter names to standard names
-        # Grok1 uses "num_experts" in its config
-        num_experts = getattr(self.config, "num_experts", 8)
+        num_experts = self.config.num_experts if self.is_grok1 else self.config.num_local_experts
         return FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="linear",  # Grok1 specific
-            ckpt_down_proj_name="linear_1",  # Grok1 specific
-            ckpt_up_proj_name="linear_v",  # Grok1 specific
+            ckpt_gate_proj_name="linear" if self.is_grok1 else "w1",
+            ckpt_down_proj_name="linear_1" if self.is_grok1 else "w2",
+            ckpt_up_proj_name="linear_v" if self.is_grok1 else "w3",
             num_experts=num_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
@@ -377,6 +562,8 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
         ]
 
         params_dict = dict(self.named_parameters())
@@ -478,12 +665,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        # Check the model architecture to handle parameter differences between Grok1 and Grok2.
+        architectures = getattr(config, "architectures", [])
+        is_grok1 = architectures and architectures[0] == "Grok1ModelForCausalLM"
+        num_experts = config.num_experts if is_grok1 else config.num_local_experts
         self.config = config
         self.lora_config = lora_config
         self.quant_config = quant_config
+        self.load_presharded_moe = (
+            getattr(config, "load_presharded_moe", True) and num_experts > 0
+            and get_tensor_model_parallel_world_size() > 1 and not is_grok1)
 
         self.model = Grok1Model(vllm_config=vllm_config,
+                                load_presharded_moe=self.load_presharded_moe,
+                                is_grok1=is_grok1,
                                 prefix=maybe_prefix(prefix, "model"))
         self.unpadded_vocab_size = config.vocab_size
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 38d300b03d2c..92cd45fb9ab1 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -88,6 +88,7 @@
     "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"),  # noqa: E501
     "GritLM": ("gritlm", "GritLM"),
     "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"),
+    "Grok1ForCausalLM": ("grok1", "Grok1ForCausalLM"),
     "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"),
     "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"),
     "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"),
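
Reviewer's note, not part of the diff: a minimal offline smoke test for the new Grok2 path, sketched under stated assumptions. `xai-org/grok-2` is the checkpoint named in the supported-models table above; `tensor_parallel_size=8`, the prompt, and the sampling settings are illustrative assumptions only. The presharded-MoE loading added here only engages when tensor parallelism is greater than 1 and the detected architecture is not `Grok1ModelForCausalLM`.

# Hypothetical smoke test; assumes a multi-GPU host with 8 devices.
from vllm import LLM, SamplingParams

if __name__ == "__main__":
    # With tensor_parallel_size > 1 and a non-Grok1 architecture, the
    # load_presharded_moe path above loads expert weights without narrowing.
    llm = LLM(model="xai-org/grok-2", tensor_parallel_size=8)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The capital of France is"], sampling_params)
    for output in outputs:
        print(output.outputs[0].text)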