-
Notifications
You must be signed in to change notification settings - Fork 450
Implement nas.convert() api for the compress algorithm #482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 48 commits
Commits
Show all changes
52 commits
Select commit
Hold shift + click to select a range
c758ad5
The main compression function for a model
danielkorzekwa 8af9903
Code formatting
danielkorzekwa 5ba6c27
Model search space configuration used by test_compress.py test.
danielkorzekwa 0bc5d84
Tokenizer used by test_compress.py test.
danielkorzekwa 87d4fa5
Tokenizer utility used by test_compress.py test
danielkorzekwa ced1e99
e2e tests for compress.py
danielkorzekwa 5de0bdc
Add convert_llama3_config_to_decilm_config + unit test
danielkorzekwa 800414c
Remove unused bypass distillation config files.
danielkorzekwa 16abcc9
Moving integration tests to tests/experimental to not trigger CICD
danielkorzekwa a5ba1c7
update docs
danielkorzekwa 1bda391
Replace mprint with print and replace osp.join with path1 / path2 not…
danielkorzekwa bb38401
Refactor file checking assertions to use .is_file() and .exists()
danielkorzekwa 8415548
Add a new dependency section to setyp.py for the modelopt.torch._comp…
danielkorzekwa b1b1833
Move test_convert_llama3_config_to_decilm_config.py to tests/experime…
danielkorzekwa d4ffc91
Merge branch 'feature/compress' into dkorzekwa/e2e_compression_test
kevalmorabia97 6f28e4a
Fix: Add missing LICENSE headers
kevalmorabia97 016fb63
Use spawn_multiprocess_job for test_compress test (to be able to use …
danielkorzekwa 0ccf1c4
Add comments.
danielkorzekwa 58439ca
Add _save_dummy_dataset to the test_compress.py
danielkorzekwa 2e5f776
Refactoring: Move torch distributed env variables to dist_utils.py
danielkorzekwa 6274db5
Refactoring: move torch distributed variables to dist_utils
danielkorzekwa d942e0a
Move os.environ["WANDB_DISABLED"] = "true" to dist_utils.py
danielkorzekwa f765921
Implement integration test for mnt.convert() for the _compress algori…
danielkorzekwa de876d6
Implement mtn.convert() for compress() algorithm.
danielkorzekwa 72bdc7a
Merge branch 'dkorzekwa/e2e_compression_test' into dkorzekwa/llama3_t…
danielkorzekwa 40d28af
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa f7fe23c
Fix broken test - incorrect package names.
danielkorzekwa 3d1d286
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa a210483
Implementing nas.convert for compress algorithm.
danielkorzekwa 739f868
Improve docs
danielkorzekwa b06d22b
Merge branch 'dkorzekwa/e2e_compression_test' into dkorzekwa/llama3_t…
danielkorzekwa 9352978
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa 20a3c5e
Code cleanup.
danielkorzekwa 18cb88b
Merge branch 'feature/compress' into dkorzekwa/llama3_to_decilm_conve…
danielkorzekwa 1033c81
Fix import
danielkorzekwa 0680c45
simplify code
danielkorzekwa 2d9da30
implementing compress_nas_plugin
danielkorzekwa febab44
code clean up.
danielkorzekwa 86bf394
code clean up
danielkorzekwa 86e04a0
create conftest.py with shared test logic for compress tests.
danielkorzekwa ae61644
code cleanup
danielkorzekwa 2998cdb
Merge branch 'dkorzekwa/llama3_to_decilm_convertion' into dkorzekwa/n…
danielkorzekwa 3778ec2
code refactoring
danielkorzekwa d940000
refactoring
danielkorzekwa 0bf9a92
move test utilities from conftest.py to test_utils.py
danielkorzekwa b56df9a
Improve comments
danielkorzekwa fd63130
Merge branch 'feature/compress' into dkorzekwa/nas_convert
danielkorzekwa 9bfcc21
Added TODO.
danielkorzekwa 6504c44
Utilitities for hydra initialization
danielkorzekwa d0fb8f9
Code refactoring
danielkorzekwa 40f18b2
code refactoring
danielkorzekwa 936556f
Add compress dependencies to setup.py.
danielkorzekwa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
169 changes: 169 additions & 0 deletions
169
modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,169 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """ | ||
| Compress NAS plugin for the Modelopt framework (based on Puzzle algorithm: https://arxiv.org/abs/2411.19146). | ||
| """ | ||
|
|
||
| import datetime | ||
| from pathlib import Path | ||
|
|
||
| import pruning_ckpts | ||
| import score_pruning_activations | ||
| import torch | ||
| from scripts.convert_llama3_to_decilm import convert_llama3_to_decilm | ||
| from torch import nn | ||
|
|
||
| from modelopt.torch._compress.runtime import NativeDdpRuntime | ||
| from modelopt.torch.nas.conversion import NASModeRegistry | ||
| from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField | ||
| from modelopt.torch.opt.mode import ( | ||
| ConvertEntrypoint, | ||
| ConvertReturnType, | ||
| MetadataDict, | ||
| ModeDescriptor, | ||
| RestoreEntrypoint, | ||
| ) | ||
| from modelopt.torch.opt.searcher import BaseSearcher | ||
|
|
||
| # TODO Move initialize_hydra_config_for_dir from tests to main | ||
| from tests.utils.test_utils import initialize_hydra_config_for_dir | ||
|
|
||
|
|
||
| class CompressModel(nn.Module): | ||
| pass # No model implementation is needed for the compress mode | ||
|
|
||
|
|
||
| class CompressConfig(ModeloptBaseConfig): | ||
| """Configuration for Compress NAS algorithm.""" | ||
|
|
||
| # Input model path to compress in the HF format | ||
| input_model_path: str = ModeloptField( | ||
| default="", | ||
| title="", | ||
| description="", | ||
| ) | ||
|
|
||
| # Hydra config directory containing the search space definition | ||
| hydra_config_dir: str = ModeloptField( | ||
| default="", | ||
| title="", | ||
| description="", | ||
| ) | ||
|
|
||
| # Hydra config name containing the search space definition | ||
| hydra_config_name: str = ModeloptField( | ||
| default="", | ||
| title="", | ||
| description="", | ||
| ) | ||
|
|
||
| # Directory to save the compressed model and intermediate results | ||
| puzzle_dir: str = ModeloptField( | ||
| default="", | ||
| title="", | ||
| description="", | ||
| ) | ||
|
|
||
| # Dataset path to use for scoring in prunining and NAS search | ||
| dataset_path: str = ModeloptField( | ||
| default="", | ||
| title="", | ||
| description="", | ||
| ) | ||
|
|
||
|
|
||
| def convert_compress_model(model: nn.Module, config: CompressConfig) -> ConvertReturnType: | ||
| """1. Convert the model from HF format to DeciLM format. | ||
| 2. Score the pruning activations. | ||
| 3. Prune the model and save pruned checkpoints | ||
|
|
||
| The output of this step will be used by mnt.search() to perform the NAS search. | ||
| """ | ||
| runtime = NativeDdpRuntime( | ||
| dtype=torch.bfloat16, torch_distributed_timeout=datetime.timedelta(10) | ||
| ) | ||
|
|
||
| # Load hydra config | ||
| hydra_cfg = initialize_hydra_config_for_dir( | ||
| config_dir=config.hydra_config_dir, | ||
| config_name=config.hydra_config_name, | ||
| overrides=[ | ||
| f"puzzle_dir={config.puzzle_dir}", | ||
| f"dataset_path={config.dataset_path}", | ||
| ], | ||
| ) | ||
|
|
||
| # Convert Llama3 model to DeciLM model | ||
| hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable | ||
| convert_llama3_to_decilm( | ||
| input_dir=config.input_model_path, | ||
| output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, | ||
| ) | ||
|
|
||
| # Score_pruning_activations (distributed processing) | ||
| score_pruning_activations.launch_score_activations(hydra_cfg, runtime) | ||
|
|
||
| # Prune the model and save pruned checkpoints | ||
| if runtime.global_rank == 0: | ||
| pruning_ckpts.launch_prune_ckpt(hydra_cfg) | ||
| runtime.wait_for_everyone() | ||
|
|
||
| return model, {} | ||
|
|
||
|
|
||
| def restore_compress_model( | ||
| model: nn.Module, config: CompressConfig, metadata: MetadataDict | ||
| ) -> nn.Module: | ||
| """Restore is not needed for the compress mode as we are not saving any model state""" | ||
| return model | ||
|
|
||
|
|
||
| @NASModeRegistry.register_mode | ||
| class CompressDescriptor(ModeDescriptor): | ||
| """Descriptor for the Compress mode.""" | ||
|
|
||
| @property | ||
| def name(self) -> str: | ||
| """String identifier for this mode.""" | ||
| return "compress" | ||
|
|
||
| @property | ||
| def config_class(self) -> type[ModeloptBaseConfig]: | ||
| """Configuration class for this mode.""" | ||
| return CompressConfig | ||
|
|
||
| @property | ||
| def search_algorithm(self) -> type[BaseSearcher]: | ||
| """Return the associated searcher implementation.""" | ||
| raise NotImplementedError("Compress mode does not have a search algorithm yet.") | ||
|
|
||
| @property | ||
| def convert(self) -> ConvertEntrypoint: | ||
| """Entrypoint to convert a model.""" | ||
| return convert_compress_model | ||
|
|
||
| @property | ||
| def restore(self) -> RestoreEntrypoint: | ||
| """Entrypoint to restore a model.""" | ||
| return restore_compress_model | ||
|
|
||
| @property | ||
| def export_mode(self) -> str | None: | ||
| """The mode that corresponds to the export mode. | ||
| For now, this will be a no-op as there is no modelopt's concept of search space defined | ||
| for the compress algorithm. | ||
| """ | ||
| return "export_nas" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.