diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index bdafed15e9ef..64c530334067 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -21,6 +21,7 @@ Chunk, Concatenate, ErnieFuseAndSplitTextVisionExperts, + Force16BytesAlignment, MergeModulelist, Transpose, WeightConverter, @@ -40,6 +41,23 @@ def _build_checkpoint_conversion_mapping(): mapping = { + "gpt_oss": [ + # NOTE: These converters are only applied if the model is being loaded from pre-dequantized checkpoint. + # If you are dequantizing the model on the fly, these converters will be ignored because the tensors + # that match these patterns are only created after dequantization. + # That's not an issue for now since the dequantization converters already ensure 16 bytes alignment + # by enforcing contiguity. + WeightConverter( + source_patterns="mlp.experts.gate_up_proj$", + target_patterns="mlp.experts.gate_up_proj", + operations=[Force16BytesAlignment()], + ), + WeightConverter( + source_patterns="mlp.experts.down_proj$", + target_patterns="mlp.experts.down_proj", + operations=[Force16BytesAlignment()], + ), + ], "mixtral": [ WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"), WeightConverter( @@ -348,6 +366,13 @@ def get_model_conversion_mapping( # Add the ones from the quantizer as well if provided if hf_quantizer is not None: + # NOTE: Since get_weight_conversions() only serve to dequantize, we would normally want to apply them first. + # However, for now it's not possible to cascade converters (i.e., applying model-specific conversions on top + # of tensors created by the dequantization conversions) + # This means that if a model has model-specific conversions and is being dequantized, the model-specific conversion + # that relies on tensors created by dequantization conversions will not be applied. + # GptOss example: with Mxfp4Config(dequantize=True), Force16BytesAlignment converters are ignored because the tensors + # "mlp.experts.gate_up_proj$" and "mlp.experts.down_proj$" are only created after dequantization conversions are applied. weight_conversions.extend(hf_quantizer.get_weight_conversions()) return weight_conversions diff --git a/src/transformers/core_model_loading.py b/src/transformers/core_model_loading.py index 92068f127368..e04244603286 100644 --- a/src/transformers/core_model_loading.py +++ b/src/transformers/core_model_loading.py @@ -440,6 +440,42 @@ def reverse_op(self) -> ConversionOps: return ErnieFuseAndSplitTextVisionExperts(stack_dim=self.stack_dim, concat_dim=self.concat_dim) +class Force16BytesAlignment(ConversionOps): + """ + Ensures that the given tensor is 16-bytes aligned in memory and clones it if not. + This garantees 16-bytes alignmenet for kernels / implementations that use TMA or SIMD instructions like torch._grouped_mm. + """ + + @torch.no_grad() + def convert( + self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str], **kwargs + ) -> dict[str, torch.Tensor]: + target_pattern = self.get_target_pattern(input_dict, source_patterns, target_patterns) + tensors = next(iter(input_dict.values())) + tensor = tensors[0] if isinstance(tensors, list) else tensors + tensor = tensor.clone() if tensor.data_ptr() % 16 != 0 else tensor + return {target_pattern: tensor} + + def get_target_pattern( + self, input_dict: dict[str, torch.Tensor], source_patterns: list[str], target_patterns: list[str] + ) -> str: + if len(input_dict) != 1: + raise ValueError("Undefined Operation encountered!") + # Here it's the first operation of a chain, so return the source + if len(target_patterns) > 1: + if len(source_patterns) == 1: + return source_patterns[0] + else: + raise ValueError("Undefined Operation encountered!") + # Here it's the only operation, or the last operation in a chain, so we return the target + else: + return target_patterns[0] + + @property + def reverse_op(self) -> ConversionOps: + return Force16BytesAlignment() + + @dataclass(slots=True) class WeightTransform: source_patterns: str | list[str] = field(init=True) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 12e4168ac0bd..1f2679573cd4 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -21,7 +21,6 @@ if is_torch_available(): import torch - # Examples of experts class with its eager mm implementation # class Experts(nn.Module): # """Collection of expert weights stored as 3D tensors.""" @@ -62,6 +61,40 @@ # return final_hidden_states +def _batched_linear( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + is_transposed: bool = False, +) -> torch.Tensor: + """Batched linear layer supporting optional bias and transposed weights. + + Args: + input (`torch.Tensor`): + Input tensor of shape (batch_size, input_dim). + weight (`torch.Tensor`): + Weight tensor of shape (batch_size, output_dim, input_dim) if transposed is `False`, + else of shape (batch_size, input_dim, output_dim). + bias (`torch.Tensor`, *optional*): + Bias tensor of shape (batch_size, output_dim). Default is `None`. + is_transposed (`bool`, *optional*, defaults to `False`): + Whether the weight tensor is transposed. + Returns: + `torch.Tensor`: Output tensor of shape (batch_size, output_dim). + """ + if is_transposed: + # (batch_size, 1, input_dim) @ (batch_size, input_dim, output_dim) -> (batch_size, 1, output_dim) -> (batch_size, output_dim) + out = torch.bmm(input.unsqueeze(1), weight).squeeze(1) + else: + # (batch_size, output_dim, input_dim) @ (batch_size, input_dim, 1) -> (batch_size, output_dim, 1) -> (batch_size, output_dim) + out = torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) + + if bias is not None: + out = out + bias + + return out + + def batched_mm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -82,40 +115,42 @@ def batched_mm_experts_forward( # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style if top_k_weights.shape == (num_tokens, num_top_k): - sample_weights = top_k_weights.reshape(-1) # (S,) + sample_weights = top_k_weights elif top_k_weights.shape == (num_tokens, num_experts): - sample_weights = top_k_weights[token_idx, expert_ids] # (S,) + # TODO: routers that output full expert distribution + # should probably be corrected to output only top_k weights + sample_weights = top_k_weights[token_idx, expert_ids] else: raise ValueError( f"top_k_weights has an invalid/unsupported shape. It should be either (num_tokens, num_top_k)({num_tokens}, {num_top_k}) " f"or (num_tokens, num_experts)({num_tokens}, {num_experts}), but got {top_k_weights.shape}." ) + sample_weights = sample_weights.reshape(-1, 1) # (S, 1) # Get current hidden states for selected samples - current_hidden_states = hidden_states[token_idx] # (S, hidden_dim) + selected_hidden_states = hidden_states[token_idx] - # Select projection matrices for selected experts - selected_gate_up = self.gate_up_proj[expert_ids] # (S, hidden_dim, 2 * intermediate_dim) - selected_down = self.down_proj[expert_ids] # (S, hidden_dim, intermediate_dim) + # Select expert weights and biases for selected samples + selected_gate_up = self.gate_up_proj[expert_ids] + selected_down = self.down_proj[expert_ids] + selected_gate_up_bias = self.gate_up_proj_bias[expert_ids] if self.has_bias else None + selected_down_bias = self.down_proj_bias[expert_ids] if self.has_bias else None # --- Up projection per expert (batched) --- - gate_up_out = torch.bmm(selected_gate_up, current_hidden_states.unsqueeze(-1)).squeeze(-1) - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids] - - # Split into gate and up components - gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim) + gate_up_out = _batched_linear( + selected_hidden_states, selected_gate_up, selected_gate_up_bias, is_transposed=self.is_transposed + ) # (S, 2 * intermediate_dim) - # Apply activation - hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) + # Apply gating + gated_out = self._apply_gate(gate_up_out) # (S, intermediate_dim) # --- Down projection per expert (batched) --- - out_per_sample = torch.bmm(selected_down, hidden_after_activation.unsqueeze(-1)).squeeze(-1) - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - out_per_sample = out_per_sample + self.down_proj_bias[expert_ids] + out_per_sample = _batched_linear( + gated_out, selected_down, selected_down_bias, is_transposed=self.is_transposed + ) # (S, hidden_dim) # Apply routing weights - out_per_sample = out_per_sample * sample_weights.unsqueeze(-1) # (S, hidden_dim) + out_per_sample = out_per_sample * sample_weights # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) @@ -124,6 +159,44 @@ def batched_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) +def _grouped_linear( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None = None, + offs: torch.Tensor | None = None, + is_transposed: bool = False, +) -> torch.Tensor: + """Grouped linear layer supporting optional bias and transposed weights. + + Args: + input (`torch.Tensor`): + Input tensor of shape (S, input_dim). + weight (`torch.Tensor`): + Weight tensor of shape (num_experts, output_dim, input_dim) if transposed is `False`, + else of shape (num_experts, input_dim, output_dim). + bias (`torch.Tensor`, *optional*): + Bias tensor of shape (num_experts, output_dim). Default is `None`. + offs (`torch.Tensor`, *optional*): + Offsets tensor indicating the boundaries of each group in the input tensor. + is_transposed (`bool`, *optional*, defaults to `False`): + Whether the weight tensor is transposed. + Returns: + `torch.Tensor`: Output tensor of shape (S, output_dim). + """ + if is_transposed: + # (S, input_dim) @ grouped (num_experts, input_dim, output_dim) -> (S, output_dim) + out = torch._grouped_mm(input, weight, offs=offs) + else: + # (S, input_dim) @ grouped (num_experts, output_dim, input_dim).T -> (S, output_dim) + out = torch._grouped_mm(input, weight.transpose(-2, -1), offs=offs) + + if bias is not None: + # We should be able to pass bias to the grouped_mm call, but it's not yet supported. + out = out + bias + + return out + + def grouped_mm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -145,57 +218,62 @@ def grouped_mm_experts_forward( expert_ids = top_k_index.reshape(-1) token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) - # Get permutation to group by expert - perm = torch.argsort(expert_ids, stable=True) - inv_perm = torch.argsort(perm, stable=True) - # Resolve routing weights per selected sample, allowing top_k_weights to be either: # - (num_tokens, num_top_k) Qwen2MoE style # - (num_tokens, num_experts) DeepseekV2 style if top_k_weights.shape == (num_tokens, num_top_k): - sample_weights = top_k_weights.reshape(-1) # (S,) + sample_weights = top_k_weights elif top_k_weights.shape == (num_tokens, num_experts): - sample_weights = top_k_weights[token_idx, expert_ids] # (S,) + # TODO: routers that output full expert distribution + # should probably be corrected to output only top_k weights + sample_weights = top_k_weights[token_idx, expert_ids] else: raise ValueError( f"top_k_weights has an invalid/unsupported shape. It should be either (num_tokens, num_top_k)({num_tokens}, {num_top_k}) " f"or (num_tokens, num_experts)({num_tokens}, {num_experts}), but got {top_k_weights.shape}." ) + sample_weights = sample_weights.reshape(-1, 1) # (S, 1) # Get current hidden states for selected samples - current_hidden_states = hidden_states[token_idx] # (S, hidden_dim) + selected_hidden_states = hidden_states[token_idx] - # Group by expert for grouped_mm + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.argsort(perm) expert_ids_g = expert_ids[perm] sample_weights_g = sample_weights[perm] - current_states_g = current_hidden_states[perm] + selected_hidden_states_g = selected_hidden_states[perm] + + # Select expert weights and biases for selected samples + # NOTE: We keep all experts here and rely on offsets to target the active ones. + # I have already implemented a version that only passes the active experts, but + # to do so I had to use torch.unique which breaks the graph capture (data-dependent). + # Also there were no speedup gains from it in my experiments, even in eager mode. + selected_gate_up = self.gate_up_proj + selected_down = self.down_proj + selected_gate_up_bias = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None + selected_down_bias = self.down_proj_bias[expert_ids_g] if self.has_bias else None # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues - # (grouped_mm_experts_forward still fails with cuda graphs but because of _grouped_mm internals) num_tokens_per_expert = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1) offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) - # --- Up projection per expert (grouped_mm) --- - gate_up_out = torch._grouped_mm(current_states_g, self.gate_up_proj.transpose(-2, -1), offs=offsets) - if hasattr(self, "gate_up_proj_bias") and self.gate_up_proj_bias is not None: - # we should be able to pass bias to the grouped_mm call, but it's still not fully supported - gate_up_out = gate_up_out + self.gate_up_proj_bias[expert_ids_g] - - # Split into gate and up components - gate, up = gate_up_out.chunk(2, dim=-1) # both have shape (S, intermediate_dim) + # --- Up projection per expert (grouped) --- + gate_up_out = _grouped_linear( + selected_hidden_states_g, selected_gate_up, selected_gate_up_bias, offsets, is_transposed=self.is_transposed + ) # (S, 2 * intermediate_dim) - # Apply activation - hidden_after_activation = self.act_fn(gate) * up # (S, intermediate_dim) + # Apply gating + gated_out = self._apply_gate(gate_up_out) # (S, intermediate_dim) - # --- Down projection per expert (grouped_mm) --- - out_per_sample_g = torch._grouped_mm(hidden_after_activation, self.down_proj.transpose(-2, -1), offs=offsets) - if hasattr(self, "down_proj_bias") and self.down_proj_bias is not None: - # we should be able to pass bias to the grouped_mm call, but it's still not fully supported - out_per_sample_g = out_per_sample_g + self.down_proj_bias[expert_ids_g] + # --- Down projection per expert (grouped) --- + out_per_sample_g = _grouped_linear( + gated_out, selected_down, selected_down_bias, offsets, is_transposed=self.is_transposed + ) # (S, hidden_dim) # Apply routing weights - out_per_sample_g = out_per_sample_g * sample_weights_g.unsqueeze(-1) + out_per_sample_g = out_per_sample_g * sample_weights_g # Restore original order out_per_sample = out_per_sample_g[inv_perm] @@ -219,24 +297,64 @@ class ExpertsInterface(GeneralInterface): ALL_EXPERTS_FUNCTIONS = ExpertsInterface() -def use_experts_implementation(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]: - original_init = experts_class.__init__ - original_forward = experts_class.forward - - @wraps(original_init) - def __init__(self, config, *args, **kwargs): - original_init(self, config, *args, **kwargs) - self.config = config - - @wraps(original_forward) - def forward(self, *args, **kwargs): - experts_forward = original_forward - - if self.config._experts_implementation != "eager": - experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] - - return experts_forward(self, *args, **kwargs) - - experts_class.__init__ = __init__ - experts_class.forward = forward - return experts_class +def _default_apply_gate(self, gate_up_out: torch.Tensor) -> torch.Tensor: + """ + Default gating mechanism: splits the gate_up_out into gate and up parts, + applies the activation function to the gate part, and multiplies it with the up part. + Args: + gate_up_out (`torch.Tensor`): + The output tensor from the gate and up projection of shape (S, 2 * intermediate_dim). + Returns: + `torch.Tensor`: The gated output tensor of shape (S, intermediate_dim). + """ + gate, up = gate_up_out.chunk(2, dim=-1) # (S, intermediate_dim) + return self.act_fn(gate) * up # (S, intermediate_dim) + + +def use_experts_implementation( + experts_class: type[torch.nn.Module] | None = None, *, is_transposed: bool = False, has_bias: bool = False +) -> type[torch.nn.Module]: + """Decorator to modify experts class to support different experts implementations. + + Args: + experts_class (`type[torch.nn.Module]`, *optional*): + The experts class to modify. If not provided, returns a decorator that can be applied to the class. + is_transposed (`bool`, *optional*, defaults to `False`): + Whether the expert weights are stored in transposed format. + has_bias (`bool`, *optional*, defaults to `False`): + Whether the expert layers include bias terms. + + Returns: + `type[torch.nn.Module]`: The modified experts class. + """ + + def wrapper(experts_class: type[torch.nn.Module]) -> type[torch.nn.Module]: + original_init = experts_class.__init__ + original_forward = experts_class.forward + + @wraps(original_init) + def __init__(self, config, *args, **kwargs): + original_init(self, config, *args, **kwargs) + self.config = config + self.has_bias = has_bias + self.is_transposed = is_transposed + + @wraps(original_forward) + def forward(self, *args, **kwargs): + experts_forward = original_forward + + if self.config._experts_implementation != "eager": + experts_forward = ALL_EXPERTS_FUNCTIONS[self.config._experts_implementation] + + return experts_forward(self, *args, **kwargs) + + if not hasattr(experts_class, "_apply_gate"): + experts_class._apply_gate = _default_apply_gate + experts_class.__init__ = __init__ + experts_class.forward = forward + return experts_class + + if experts_class is not None: + return wrapper(experts_class) + + return wrapper diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 623f20f4c8cb..bab3a2a09220 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3586,7 +3586,15 @@ def get_init_context(cls, dtype: torch.dtype, is_quantized: bool, _is_ds_init_ca return init_contexts - def set_use_kernels(self, use_kernels, kernel_config): + def set_use_kernels(self, use_kernels, kernel_config: KernelConfig | None = None): + """ + Set whether or not to use the `kernels` library to kernelize some layers of the model. + Args: + use_kernels (`bool`): + Whether or not to use the `kernels` library to kernelize some layers of the model. + kernel_config (`KernelConfig`, *optional*): + The kernel configuration to use to kernelize the model. If `None`, the default kernel mapping will be used. + """ if use_kernels: if not is_kernels_available(): raise ValueError( diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 311b1baf15b5..56e894119b33 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -27,7 +27,7 @@ from ... import initialization as init from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin -from ...integrations import use_kernel_forward_from_hub, use_kernelized_func +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_layers import ( GenericForSequenceClassification, @@ -64,86 +64,52 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" +@use_experts_implementation(is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size - self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.intermediate_size)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_size)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.intermediate_size, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - """ - When training it is more efficient to just loop over the experts and compute the output for each expert - as otherwise the memory would explode. + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + return gated_output - For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. - - Args: - hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) - selected_experts (torch.Tensor): (batch_size * seq_len, top_k) - routing_weights (torch.Tensor): (batch_size * seq_len, top_k) - Returns: - torch.Tensor - """ - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - if hidden_states.device.type == "cpu" or self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ) # masking is also a class - expert_mask = expert_mask.permute(2, 1, 0) - # we sum on the top_k and on the sequence length to get which experts - # are hit this time around - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit[:]: - # expert_idx only have 1 element, so we can use scale for fast indexing - expert_idx = expert_idx[0] - # skip masking index - if expert_idx == self.num_experts: - continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, top_k_pos, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - else: - num_tokens = hidden_states.shape[0] - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) - next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size) - - full_routing_weights = torch.zeros( - num_tokens, self.num_experts, device=routing_weights.device, dtype=routing_weights.dtype - ) - full_routing_weights.scatter_(1, router_indices, routing_weights) - full_routing_weights = full_routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1, 1) + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ) # masking is also a class + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + # expert_idx only have 1 element, so we can use scale for fast indexing + expert_idx = expert_idx[0] + # skip masking index + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gated_output = self._apply_gate(gate_up) + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, top_k_pos, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states * full_routing_weights - next_states = next_states.sum(dim=0) return next_states @@ -157,11 +123,9 @@ def __init__(self, config): self.bias = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = router_top_value + router_scores = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_logits, router_scores, router_indices @@ -173,9 +137,12 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_dim) _, router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices, router_scores) - return routed_out, router_scores + hidden_states = self.experts(hidden_states, router_indices, router_scores) + hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return hidden_states, router_scores class GptOssRotaryEmbedding(nn.Module): diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 2db276752010..267d42bd9e20 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -19,7 +19,7 @@ from ... import initialization as init from ...cache_utils import Cache, DynamicCache -from ...integrations import use_kernel_forward_from_hub +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask from ...modeling_outputs import ( MoeModelOutputWithPast, @@ -61,86 +61,52 @@ def forward(self, hidden_states): return (self.weight * hidden_states).to(input_dtype) # main diff with Llama +@use_experts_implementation(is_transposed=True, has_bias=True) class GptOssExperts(nn.Module): def __init__(self, config): super().__init__() self.intermediate_size = config.intermediate_size self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size - self.expert_dim = self.intermediate_size - self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim)) - self.gate_up_proj_bias = nn.Parameter(torch.zeros(self.num_experts, 2 * self.expert_dim)) - self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.expert_dim, self.hidden_size))) - self.down_proj_bias = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size)) + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.intermediate_size)) + self.gate_up_proj_bias = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_size)) + self.down_proj = nn.Parameter(torch.empty((self.num_experts, self.intermediate_size, self.hidden_size))) + self.down_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.hidden_size)) self.alpha = 1.702 self.limit = 7.0 + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + return gated_output + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - """ - When training it is more efficient to just loop over the experts and compute the output for each expert - as otherwise the memory would explode. - - For inference we can sacrifice some memory and compute the output for all experts at once. By repeating the inputs. - - Args: - hidden_states (torch.Tensor): (batch_size, seq_len, hidden_size) - selected_experts (torch.Tensor): (batch_size * seq_len, top_k) - routing_weights (torch.Tensor): (batch_size * seq_len, top_k) - Returns: - torch.Tensor - """ - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - if hidden_states.device.type == "cpu" or self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ) # masking is also a class - expert_mask = expert_mask.permute(2, 1, 0) - # we sum on the top_k and on the sequence length to get which experts - # are hit this time around - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - for expert_idx in expert_hit[:]: - # expert_idx only have 1 element, so we can use scale for fast indexing - expert_idx = expert_idx[0] - # skip masking index - if expert_idx == self.num_experts: - continue - with torch.no_grad(): - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, top_k_pos, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - else: - num_tokens = hidden_states.shape[0] - hidden_states = hidden_states.repeat(self.num_experts, 1) - hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) - next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size) - - full_routing_weights = torch.zeros( - num_tokens, self.num_experts, device=routing_weights.device, dtype=routing_weights.dtype - ) - full_routing_weights.scatter_(1, router_indices, routing_weights) - full_routing_weights = full_routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1, 1) + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot( + router_indices, num_classes=self.num_experts + ) # masking is also a class + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + # expert_idx only have 1 element, so we can use scale for fast indexing + expert_idx = expert_idx[0] + # skip masking index + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gated_output = self._apply_gate(gate_up) + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, top_k_pos, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states * full_routing_weights - next_states = next_states.sum(dim=0) return next_states @@ -154,11 +120,9 @@ def __init__(self, config): self.bias = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = F.linear(hidden_states, self.weight, self.bias) # (num_tokens, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (num_tokens, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = router_top_value + router_scores = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) return router_logits, router_scores, router_indices @@ -170,9 +134,12 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_dim) _, router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices, router_scores) - return routed_out, router_scores + hidden_states = self.experts(hidden_states, router_indices, router_scores) + hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return hidden_states, router_scores class GptOssRotaryEmbedding(Qwen2RotaryEmbedding): diff --git a/tests/models/gpt_oss/test_modeling_gpt_oss.py b/tests/models/gpt_oss/test_modeling_gpt_oss.py index 772fd7f03013..23c92d900e35 100644 --- a/tests/models/gpt_oss/test_modeling_gpt_oss.py +++ b/tests/models/gpt_oss/test_modeling_gpt_oss.py @@ -120,6 +120,9 @@ def test_flex_attention_with_grads(self): def test_generate_compile_model_forward_fullgraph(self): return super().test_generate_compile_model_forward_fullgraph() + def test_reverse_loading_mapping(self, check_keys_were_modified=False): + super().test_reverse_loading_mapping(check_keys_were_modified) + RESULTS_PATH = Path(__file__).parent.parent.parent / "fixtures/gpt_oss/integration_tests.json" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2cf6715b94a0..0ea261430b0d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -559,13 +559,13 @@ def _test_eager_matches_batched_and_grouped_inference(self, name, dtype): for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() set_config_for_less_flaky_test(config) - model = model_class(config) + model = model_class(config).eval().to(torch_device).to(dtype) set_model_for_less_flaky_test(model) - # Load with dtype + # Reload to find any buffer misalignments after saving/loading with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, dtype=dtype).eval().to(torch_device) + model = model_class.from_pretrained(tmpdirname).eval().to(torch_device).to(dtype) with torch.no_grad(): inputs_dict = {k: v.to(dtype) if torch.is_floating_point(v) else v for k, v in inputs_dict.items()} @@ -838,6 +838,7 @@ def test_save_load(self): model = model_class.from_pretrained(tmpdirname) model.to(torch_device) + model.eval() with torch.no_grad(): second = model(**self._prepare_for_class(inputs_dict, model_class))[0] @@ -2687,6 +2688,11 @@ def test_multi_gpu_data_parallel_forward(self): model.to(0) model.eval() + if model.config._experts_implementation == "grouped_mm": + # DataParallel does not respect buffer alignment when replicating the model on + # multiple GPUs, which can cause errors in grouped_mm experts implementation. + model.set_experts_implementation("eager") + # Wrap model in nn.DataParallel model = nn.DataParallel(model) torch.cuda.synchronize() # otherwise the transfer might not be complete @@ -4460,7 +4466,7 @@ def test_tp_plan_matches_params(self): len(unused_entries) == 0, f"The following entries of the TP-plan are not valid: {unused_entries}" ) - def test_reverse_loading_mapping(self): + def test_reverse_loading_mapping(self, check_keys_were_modified=True): """Make sure we can load and save correctly the models having any weight renaming mapping or weight conversion mapping. Note that this test would be better if we could start from the serialized keys, and check that the model @@ -4469,6 +4475,11 @@ def test_reverse_loading_mapping(self): reverse the conversion and then check that those converted keys match correctly the conversions. However, all the checks performed here should ensure everything is going as it should. + + Args: + check_keys_were_modified (`bool`, *optional*, defaults to `True`): + Whether to expect keys being modified or not. In some cases, models do not change keys but + their weights, e.g. via transpose, memory alignment, etc. """ config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -4499,8 +4510,9 @@ def test_reverse_loading_mapping(self): # Get all the serialized keys that we just saved according to the reverse mapping serialized_keys = list(state_dict.keys()) - # They should be different, otherwise we did not perform any mapping - self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") + if check_keys_were_modified: + # They should be different, otherwise we did not perform any mapping + self.assertNotEqual(sorted(serialized_keys), sorted(model_keys), "No key mapping was performed!") # Check that for each conversion entry, we at least map to one key for conversion in conversions: