Enhance Autoround to support multiple cards tuning #2157
Merged
Commits (34)

- bc97d48 update autoround version (yiliu30)
- 19ab4f2 Merge branch 'main' into autoround-version (yiliu30)
- 9ba113c expose bs (yiliu30)
- 646982a Merge branch 'autoround-version' of https://github.com/yiliu30/llm-co… (yiliu30)
- 1050335 use 0.9.1 (yiliu30)
- 50e6682 fix (yiliu30)
- d139071 update (yiliu30)
- a0affbd enable auto-dispatch (yiliu30)
- 17ba9f5 add ds example (yiliu30)
- cd943cd merge main (yiliu30)
- 8338ed5 pass ignore to ar (yiliu30)
- 56515af add qwen example (yiliu30)
- ad6c1c0 update example (yiliu30)
- 09a72c0 format (yiliu30)
- af112bd update (yiliu30)
- ec98118 refine suspend hook (yiliu30)
- c5eae60 update (yiliu30)
- 2d482fc clean code (yiliu30)
- 17b7e45 add ut (yiliu30)
- 7a9b3cd fix (yiliu30)
- 4f45b17 fix hint (yiliu30)
- 0fac601 refine (yiliu30)
- 0f7a990 speedup ut (yiliu30)
- 58ef017 clean (yiliu30)
- c9ea99c add docstring (yiliu30)
- d2a7c92 format (yiliu30)
- d48c3d6 Merge branch 'main' into auto-device (yiliu30)
- 993a68e Merge branch 'main' into auto-device (yiliu30)
- fa8cdcc Merge branch 'main' into auto-device (yiliu30)
- c17e923 rename device_map to device_ids (yiliu30)
- 1092cde fix typo (yiliu30)
- 0734dd5 add docstring (yiliu30)
- d3e6da6 Merge branch 'main' into auto-device (yiliu30)
- 6d4934a fix format issue (yiliu30)
New example file (69 lines):

```python
from auto_round.calib_dataset import get_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.autoround import AutoRoundModifier
from llmcompressor.utils import dispatch_for_generation

# Select model and load it.
model_id = "Qwen/Qwen3-235B-A22B"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Select calibration dataset.
NUM_CALIBRATION_SAMPLES = 128
MAX_SEQUENCE_LENGTH = 2048
ITERS = 200

# Get aligned calibration dataset.
ds = get_dataset(
    tokenizer=tokenizer,
    seqlen=MAX_SEQUENCE_LENGTH,
    nsamples=NUM_CALIBRATION_SAMPLES,
)

# Configure the quantization algorithm to run.
# * Quantize the weights to 4 bits with AutoRound, using a group size of 128.
# * For `Qwen/Qwen3-235B-A22B`, tuning with the default settings requires
#   about 300 GB of memory.
recipe = AutoRoundModifier(
    targets="Linear",
    scheme="W4A16",
    ignore=[
        "lm_head",
        "re:.*mlp.gate$",
    ],
    iters=ITERS,
    enable_torch_compile=False,
    device_ids="0,1,2,3",  # Use 4 A100 GPUs
)

# Apply algorithms.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    shuffle_calibration_samples=False,
)

# Save the compressed model to disk.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
print(f"save to {SAVE_DIR}")
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

# Confirm generations of the quantized model look sane.
print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
sample = tokenizer("Hello my name is", return_tensors="pt")
sample = {key: value.to(model.device) for key, value in sample.items()}
output = model.generate(**sample, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")
```
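Two small string conventions in the example can be checked in isolation, without GPUs or the libraries installed: `device_ids` is a comma-separated string of GPU indices, and the save directory is derived from the last path component of the model id. This is a minimal standalone sketch mirroring those two lines; the `gpu_indices` list is an illustrative assumption, not part of the original script.

```python
# device_ids is a comma-separated string of GPU indices. Building it from a
# list keeps the device count explicit (4 GPUs in the example above).
gpu_indices = [0, 1, 2, 3]  # hypothetical: matches the 4-GPU example
device_ids = ",".join(str(i) for i in gpu_indices)
print(device_ids)  # 0,1,2,3

# The save directory takes the last path component of the model id and
# appends the quantization scheme suffix.
model_id = "Qwen/Qwen3-235B-A22B"
save_dir = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128-AutoRound"
print(save_dir)  # Qwen3-235B-A22B-W4A16-G128-AutoRound
```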