Added GLM Modeling #2170
Merged
Changes from all commits (26 commits)
- 7db5f4e Added GLM Modeling. (phaelon74)
- e1e7b58 Adding the test file for GLM-MoE. (phaelon74)
- a7b0e8a Adding Trust Remote Code for testing calibration. (phaelon74)
- f02b1be Fixing the Loop Gemini identifed. (phaelon74)
- c908f77 Adding example script for GLM-4.7 Quanting. (phaelon74)
- 1df156b Merge branch 'main' into GLM-Modeling (dsikka)
- d9913c9 Changing Ignore of first three dense layers to Regex matching pattern. (phaelon74)
- f4d1e95 Updating non dense layers to also use a Regex in the Ignore Section. (phaelon74)
- 5429a71 Update the Test_Calib_glm4_moe.py with proper stub directory. (phaelon74)
- 8caaf28 Updating the Example script to use argument paramaters at script laun… (phaelon74)
- 81ce666 Merge branch 'main' into GLM-Modeling (dsikka)
- 9d3618b Address the items identified during quality check. (phaelon74)
- dfc779f Merge branch 'main' into GLM-Modeling (phaelon74)
- 166de90 Merge branch 'main' into GLM-Modeling (brian-dellabetta)
- e4eed39 Merge branch 'main' into GLM-Modeling (dsikka)
- 3ffa14d Updating the order of datasets import. (phaelon74)
- 1e200e9 Merge branch 'main' into GLM-Modeling (dsikka)
- e2e1b77 Removed AWQMappings and utilized AWQModifier. Also updated auto_dtyp… (phaelon74)
- db7de05 Created a helper function for fixing the generation config. (phaelon74)
- 9a7e115 Merge branch 'main' into GLM-Modeling (dsikka)
- fdc5ffb Merge branch 'main' into GLM-Modeling (dsikka)
- d49fb67 Merge branch 'main' into GLM-Modeling (dsikka)
- 19be730 Simplified the GLM4_7 Example script. (phaelon74)
- 014cfea Adding MOE_ignore layers back into Simplified Script. (phaelon74)
- 4f0bb66 Fixing Import Order. (phaelon74)
- 083c35b Format (dsikka)
New file (+82 lines): example script for quantizing GLM-4.7 with AWQ.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modeling.glm4_moe import CalibrationGlm4MoeMoE  # noqa: F401
from llmcompressor.modifiers.awq import AWQModifier

# Load the model
model_id = "zai-org/GLM-4.7"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# MoE calibration is now handled automatically by the pipeline.
# The `CalibrationGlm4MoeMoE` modules (from `llmcompressor.modeling.glm4_moe`)
# will be applied during calibration to enable proper expert calibration.
# These replace the original `Glm4MoeMoE` class from
# `transformers.models.glm4_moe.modeling_glm4_moe`.

# Select calibration dataset.
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 512
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

moe_ignores = [
    # Layers 0-2: dense layers - ignore entire layers
    "model.layers.0.*",
    "model.layers.1.*",
    "model.layers.2.*",
    # Ignore the output head
    "lm_head",
]

# Configure the quantization algorithm to run.
# * quantize the weights to 4 bit with AWQ with a group size of 128
recipe = AWQModifier(targets="Linear", scheme="W4A16", ignore=moe_ignores)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
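The `moe_ignores` entries in the script above are glob-style patterns over module names. As a minimal, self-contained sketch of how such patterns select layers, the snippet below matches hypothetical GLM layer names with Python's stdlib `fnmatch`; this is illustrative only, since llm-compressor's own matching logic (which the commit history notes also supports regex patterns) may differ in detail, and the layer names here are made up for the example.

```python
from fnmatch import fnmatch

# Same ignore patterns as the example script above.
moe_ignores = [
    "model.layers.0.*",
    "model.layers.1.*",
    "model.layers.2.*",
    "lm_head",
]

# Hypothetical module names standing in for a real GLM-4 model tree.
layer_names = [
    "model.layers.0.mlp.gate_proj",         # dense layer 0 -> ignored
    "model.layers.2.self_attn.q_proj",      # dense layer 2 -> ignored
    "model.layers.3.mlp.experts.0.up_proj", # MoE layer -> quantized
    "lm_head",                              # output head -> ignored
]

# A name is skipped if any ignore pattern matches it.
ignored = [
    name
    for name in layer_names
    if any(fnmatch(name, pattern) for pattern in moe_ignores)
]
```

Note that `fnmatch` treats `*` as matching any characters, including dots, so `model.layers.0.*` covers every submodule of layer 0.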
New file (+92 lines): the `llmcompressor.modeling.glm4_moe` calibration module.

```python
import torch
from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig
from transformers.models.glm4_moe.modeling_glm4_moe import (
    Glm4MoeMoE as OriginalGlm4MoeMoE,
)

from llmcompressor.modeling.moe_context import MoECalibrationModule


@MoECalibrationModule.register("Glm4MoeMoE")
class CalibrationGlm4MoeMoE(MoECalibrationModule):
    """
    Calibration version of Glm4MoeMoE that sends all tokens to all experts.

    During calibration, when calibrate_all_experts=True, all tokens are sent to
    all experts to ensure proper quantization statistics are collected for every
    expert, not just those activated by the calibration data routing.
    """

    is_permanent = False

    def __init__(
        self,
        original: OriginalGlm4MoeMoE,
        config: Glm4MoeConfig,
        calibrate_all_experts: bool = True,
    ):
        super().__init__()
        self.config = config
        self.experts = original.experts
        self.gate = original.gate
        self.shared_experts = original.shared_experts
        self.calibrate_all_experts = calibrate_all_experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional calibration mode.

        When calibrate_all_experts=True:
        - All tokens are sent to all experts for calibration
        - Routing weights are still used for the final output combination
        - This ensures all experts see calibration data

        When calibrate_all_experts=False:
        - Normal MoE routing behavior (only routed tokens go to each expert)
        """
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        # Begin MoE - inline the moe() method logic with calibration support
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        expert_mask = torch.nn.functional.one_hot(
            topk_indices, num_classes=len(self.experts)
        )
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx, expert in enumerate(self.experts):
            mask = expert_mask[expert_idx]
            token_indices, weight_indices = torch.where(mask)
            has_tokens = token_indices.numel() > 0

            if self.calibrate_all_experts:
                # When calibrating, run all tokens through the expert to gather stats.
                # The output is still calculated using only the routed tokens.
                expert_output_full = expert(hidden_states)
                if not has_tokens:
                    # No tokens routed to this expert, but stats were gathered.
                    continue
                expert_output = expert_output_full[token_indices]
            else:
                # Standard MoE behavior: only process tokens routed to this expert.
                if not has_tokens:
                    continue
                expert_output = expert(hidden_states[token_indices])

            # Common logic for combining expert outputs
            expert_weights = topk_weights[token_indices, weight_indices]
            weighted_output = expert_output * expert_weights.unsqueeze(-1)
            final_hidden_states.index_add_(0, token_indices, weighted_output)
        # End MoE

        hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states

    def restore(self, original: torch.nn.Module) -> torch.nn.Module:
        """
        Restore the original module structure.

        Since is_permanent=False, this method is called when exiting
        the calibration context to restore the original MoE module.
        """
        return original
```
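The key property of the forward pass above is that calibrate-all-experts mode changes which tokens each expert *sees* without changing the module's *output*: every expert processes the full batch, but only the routed rows are weighted into the result. The toy sketch below demonstrates that equivalence with made-up linear "experts" and a random softmax gate standing in for GLM-4's gating; all names and shapes here are invented for illustration.

```python
import torch

torch.manual_seed(0)
num_tokens, hidden, num_experts, top_k = 6, 4, 3, 2

# Toy "experts": plain linear layers standing in for GLM-4 MoE expert MLPs.
experts = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)]
hidden_states = torch.randn(num_tokens, hidden)

# Toy gate: top-k experts per token with softmax routing weights.
logits = torch.randn(num_tokens, num_experts)
topk_weights, topk_indices = torch.topk(torch.softmax(logits, dim=-1), top_k, dim=-1)


def run(calibrate_all_experts: bool) -> torch.Tensor:
    out = torch.zeros_like(hidden_states)
    # (num_tokens, top_k, num_experts) -> (num_experts, num_tokens, top_k)
    mask = torch.nn.functional.one_hot(
        topk_indices, num_classes=num_experts
    ).permute(2, 0, 1)
    for idx, expert in enumerate(experts):
        token_indices, slot_indices = torch.where(mask[idx])
        if calibrate_all_experts:
            # Every token visits the expert (stats gathered for all inputs) ...
            full = expert(hidden_states)
            if token_indices.numel() == 0:
                continue
            # ... but only routed rows contribute to the output.
            expert_out = full[token_indices]
        else:
            if token_indices.numel() == 0:
                continue
            expert_out = expert(hidden_states[token_indices])
        weights = topk_weights[token_indices, slot_indices].unsqueeze(-1)
        out.index_add_(0, token_indices, expert_out * weights)
    return out


with torch.no_grad():
    # Both modes produce the same routed output.
    assert torch.allclose(run(True), run(False), atol=1e-6)
```

This is the same invariant the unit test below checks against the real `Glm4MoeMoE` module.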
New file (+91 lines): tests for the GLM-4 MoE calibration module.

```python
import contextlib
from functools import partial

import pytest
import torch
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.glm4_moe import CalibrationGlm4MoeMoE
from llmcompressor.modeling.moe_context import moe_calibration_context
from llmcompressor.utils.dev import skip_weights_download
from llmcompressor.utils.helpers import calibration_forward_context
from tests.testing_utils import requires_cadence, requires_gpu

Glm4MoeConfig = pytest.importorskip(
    "transformers.models.glm4_moe.configuration_glm4_moe",
    reason="Glm4MoeConfig not available in this version of transformers",
).Glm4MoeConfig
OriginalGlm4MoeMoE = pytest.importorskip(
    "transformers.models.glm4_moe.modeling_glm4_moe",
    reason="Glm4MoeMoE not available in this version of transformers",
).Glm4MoeMoE


@requires_cadence("weekly")
@pytest.mark.parametrize("model_stub", ["zai-org/GLM-4.7"])
def test_calib_replace_glm4moe_all_experts(model_stub):
    with skip_weights_download():
        model = AutoModelForCausalLM.from_pretrained(model_stub, trust_remote_code=True)

    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

        # Find a GLM4 MoE layer
        moe_layer = None
        for _, module in model.named_modules():
            if isinstance(module, CalibrationGlm4MoeMoE):
                moe_layer = module
                break

        assert moe_layer is not None

        num_experts = len(moe_layer.experts)
        expert_triggered = [False for _ in range(num_experts)]

        # Define the hook function
        def hook_fn(i, module, input, output):
            expert_triggered[i] = True

        # Attach hooks using functools.partial to bind each index
        for i, expert in enumerate(moe_layer.experts):
            expert.register_forward_hook(partial(hook_fn, i))

        # Create a dummy input tensor that simulates hidden_states
        hidden_dim = model.config.hidden_size
        batch, seq_len = 4, 32
        sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)

        # Forward through the MoE layer directly
        with torch.no_grad():
            _ = moe_layer(sample)

        # Assert all experts are used
        assert all(
            expert_triggered
        ), f"Not all experts were triggered: {expert_triggered}"


@requires_gpu
def test_calib_glm4moe_module():
    config = Glm4MoeConfig()
    with torch.device("cuda"):
        original = OriginalGlm4MoeMoE(config).eval()

    # Create a dummy input tensor that simulates hidden_states
    hidden_dim = config.hidden_size
    batch, seq_len = 4, 32
    sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")

    with calibration_forward_context(original):
        true_output = original(sample)

    module = CalibrationGlm4MoeMoE(original, config, calibrate_all_experts=True)
    with calibration_forward_context(module):
        output = module(sample)
    assert torch.allclose(true_output, output, atol=1e-6)

    module = CalibrationGlm4MoeMoE(original, config, calibrate_all_experts=False)
    with calibration_forward_context(module):
        output = module(sample)
    assert torch.allclose(true_output, output, atol=1e-6)
```
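The first test above verifies expert coverage with a small trick worth noting: `functools.partial` binds each expert's index into a shared forward hook, so one function can record which submodules actually ran. A minimal, self-contained sketch of that pattern (using `torch.nn.Identity` stand-ins rather than real experts):

```python
from functools import partial

import torch

# Three trivial submodules standing in for MoE experts.
layers = torch.nn.ModuleList([torch.nn.Identity() for _ in range(3)])
triggered = [False] * len(layers)


def hook_fn(i, module, inputs, output):
    # PyTorch calls forward hooks as hook(module, inputs, output);
    # partial() has already supplied the bound index `i`.
    triggered[i] = True


for i, layer in enumerate(layers):
    layer.register_forward_hook(partial(hook_fn, i))

# Run only the first two layers; their hooks fire, the third does not.
x = torch.zeros(1)
for layer in list(layers)[:2]:
    layer(x)

assert triggered == [True, True, False]
```

Binding the index with `partial` avoids the classic late-binding pitfall of capturing the loop variable in a closure, where every hook would see the final value of `i`.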