Autoquant v2 initial version #1240

Merged: 15 commits into pytorch:main on Nov 21, 2024
Conversation

@jerryzh168 (Contributor) commented Nov 8, 2024

Summary:
We refactored v1 to benchmark subgraphs of (prev_op -> linear -> post_op) in order to get a more accurate timing estimate. One issue is that we now need to account for the batch size of the subgraph, so we would need the batch size dimension to use a symbolic shape, which does not seem to be well supported in torch.compile right now.
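
A minimal sketch of the symbolic batch dimension idea, assuming torch._dynamo.mark_dynamic and a made-up stand-in subgraph (this is not code from this PR):

```python
# Hedged sketch: ask torch.compile to treat the batch dimension of a
# benchmarked subgraph symbolically, so the compiled artifact is not
# specialized to one batch size. `subgraph` and `example_input` are
# hypothetical placeholders, not names from this PR.
import torch
import torch._dynamo

device = "cuda" if torch.cuda.is_available() else "cpu"

# stand-in for an extracted (prev_op -> linear -> post_op) subgraph
subgraph = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
).to(device)

example_input = torch.randn(8, 1024, device=device)

# dim 0 is the batch dimension; mark it dynamic instead of letting the
# compiler specialize on batch size 8
torch._dynamo.mark_dynamic(example_input, 0)

compiled = torch.compile(subgraph)
out = compiled(example_input)
# ideally a different batch size reuses the same compiled code rather than
# triggering a recompile; the PR notes this path is not well supported yet
out2 = compiled(torch.randn(16, 1024, device=device))
```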

Current Status:

  • tested with llama2 and sam
  • llama2 has the same result as autoquant v1, for both the default qtensor subclass list and the one that contains int4
  • sam gets some speedup over v1 because it picked an int8dyn layer, while autoquant v1 picks float for everything

More improvements:

  • the current batch size adjustment code is hardcoded to work for the llama model; we need to think of a way to generalize it
  • currently we use the GraphModule as the cache key and compare graph equality to avoid duplicated benchmarking effort; we have a naive graph equality check function, which should work reasonably well, but we could improve this with a subgraph matcher or by canonicalizing the graph (if such a tool exists); see the sketch after this list
  • add accuracy sanity checks
  • apply to more models
  • the fqn from named_modules does not match the extracted fqn (in the dynamo tracing stack) in some torchbench models, but we can fix this when it appears in the models we care about
  • I also heard from Animesh that modules are inlined by default now, and that we should rely on node.meta for tracking where nodes come from when extracting subgraphs; we can revisit this as well
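
For the graph-equality caching bullet above, here is a hedged sketch of what a naive canonical key over a torch.fx graph could look like; this is not the PR's implementation, and a real version would also need to encode how node inputs are wired:

```python
# Naive canonical key for an FX graph: the sequence of (op, target) pairs,
# with module targets replaced by their class name so differently named
# modules of the same type compare equal. Structurally identical subgraphs
# then map to the same benchmarking-cache entry.
import torch
import torch.fx


def naive_graph_key(gm: torch.fx.GraphModule) -> tuple:
    key = []
    for node in gm.graph.nodes:
        if node.op == "call_module":
            target = type(gm.get_submodule(node.target)).__name__
        else:
            target = getattr(node.target, "__name__", str(node.target))
        key.append((node.op, target))
    return tuple(key)


# usage: two traces of the same architecture hit the same cache entry
class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))


benchmark_cache = {}
gm1 = torch.fx.symbolic_trace(Block())
gm2 = torch.fx.symbolic_trace(Block())
benchmark_cache[naive_graph_key(gm1)] = "measured timing for gm1"
assert naive_graph_key(gm2) in benchmark_cache  # gm2 reuses gm1's result
```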

Test Plan:
Testing with torchao/_models/llama/generate.py

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant_v2-int4


| llama2 | autoquant_v1 | autoquant_v2 |
| --- | --- | --- |
| default qtensor list | Average tokens/sec: 172.59, Average Bandwidth: 1165.98 GB/s, Peak Memory Usage: 8.65 GB, Model Size: 6.76 GB (https://www.internalfb.com/phabricator/paste/view/P1680313475) | Average tokens/sec: 173.82, Average Bandwidth: 1174.28 GB/s, Peak Memory Usage: 9.61 GB, Model Size: 6.76 GB (https://www.internalfb.com/phabricator/paste/view/P1680309154) |
| int4 qtensor list | Average tokens/sec: 208.69, Average Bandwidth: 807.64 GB/s, Peak Memory Usage: 4.53 GB, Model Size: 3.87 GB (https://www.internalfb.com/phabricator/paste/view/P1680316118) | Average tokens/sec: 209.08, Average Bandwidth: 809.15 GB/s, Peak Memory Usage: 5.09 GB, Model Size: 3.87 GB (https://www.internalfb.com/phabricator/paste/view/P1680296091) |

| sam - image_encoder | base | autoquant_v1 (all float) | autoquant_v2 (one layer picked int8dyn) |
| --- | --- | --- | --- |
| default qtensor list | 23.18455616 | 23.09307945 | 24.19601284 |

Raw benchmark rows (sam):

  • base: cuda,vit_h,32,13678,16,23.18455616196341,43.132160607870524,0.5811261131824416,max-autotune,torch.bfloat16,None,False,True,True,32,154,4928,None,None
  • autoquant_v1: cuda,vit_h,32,13681,16,23.093079445154853,43.30301648920233,0.5811880744669748,max-autotune,torch.bfloat16,autoquant,False,True,True,32,154,4928,None,None
  • autoquant_v2: cuda,vit_h,32,56597,69,24.196012837520133,41.32912338554085,0.5827170017253231,max-autotune,torch.bfloat16,autoquant_v2,False,True,True,32,154,4928,None,None

Reviewers:

Subscribers:

Tasks:

Tags:

pytorch-bot (bot) commented Nov 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1240

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 0019456 with merge base d224653:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the `CLA Signed` label on Nov 8, 2024
Diff context for the inline review comment below:

    torch.nn.Linear(*new_shape, dtype=weight_val.dtype),
    ).cuda()

    else:
A reviewer (Contributor) commented:
this file has some complexity for extracting (prev_op -> linear1 -> maybe_linear_2 -> next_ops) because the models we originally studied had back-to-back linears. If you only care about transformer models, you can simplify this code quite a bit by removing the special logic for extracting the second linear. Happy to point to the right places in the code if needed.

@jerryzh168 (author) replied:
Let me clean this up a bit later, since I'm still not sure whether we will need to reimplement this functionality with some other approach; I'll figure that out as we expand testing to more models.
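
To illustrate the back-to-back linear pattern discussed in this thread, here is a hedged sketch (not the file's actual extraction logic) of detecting that pattern in a torch.fx graph:

```python
# Find back-to-back linears: a call_module node targeting nn.Linear whose
# output feeds directly into another nn.Linear call_module node.
import torch
import torch.fx


def find_back_to_back_linears(gm: torch.fx.GraphModule):
    def is_linear(node):
        return node.op == "call_module" and isinstance(
            gm.get_submodule(node.target), torch.nn.Linear
        )

    pairs = []
    for node in gm.graph.nodes:
        if is_linear(node):
            pairs.extend((node, user) for user in node.users if is_linear(user))
    return pairs


# usage on a toy module with two consecutive linears
class TwoLinears(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(8, 8)
        self.l2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.l2(self.l1(x))


gm = torch.fx.symbolic_trace(TwoLinears())
print(find_back_to_back_linears(gm))  # one (l1 node, l2 node) pair
```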

Diff context for the inline review comment below:

    return True
    return False

    def debug_single_linear(
A reviewer (Contributor) commented:
depending on what you're using this file for, this function might also be deletable

@jerryzh168 (author) replied:
Yeah, I will refine this more when it's closer to landing; right now I'm just experimenting to see whether this approach improves on the original one for the models we care about.

@jerryzh168 marked this pull request as draft on November 12, 2024 at 00:35
@jerryzh168 marked this pull request as ready for review on November 16, 2024 at 02:59
@jerryzh168 added the `topic: not user facing` label on Nov 16, 2024
@drisspg (Contributor) commented Nov 20, 2024

One overall nit is that since this seems like a prototype that will eventually back the main autoquant API, we should probably put this in the prototype folder until we're ready to move.

@jerryzh168 (author) replied:

> One overall nit is that since this seems like a prototype that will eventually back the main autoquant API, we should probably put this in the prototype folder until we're ready to move.

oh OK makes sense, I can move it

Summary:
We refactored v1 to benchmark subgraphs of (prev_op -> linear -> post_op) in order to get a more accurate timing estimate.
One issue is that we now need to account for the batch size of the subgraph, so we would need the batch size dimension to use a
symbolic shape, which does not seem to be well supported in torch.compile right now.

More improvements:
* the current batch size adjustment code is hardcoded to work for the llama model; we need to think of a way to generalize it
* use a canonicalized subgraph as the cache key to reduce the number of times we need to benchmark
* add accuracy sanity checks

Test Plan:
Testing with torchao/_models/llama/generate.py

```
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant_v2-int4
```

Reviewers:

Subscribers:

Tasks:

Tags:
@jerryzh168 (author) commented:

Thanks @drisspg @vkuzo for the review; I have addressed the comments, please take a look again.

@jerryzh168 merged commit 7446433 into pytorch:main on Nov 21, 2024
18 checks passed
@jerryzh168 deleted the autoquant-v2 branch on November 21, 2024 at 22:39