[Attention] Attention head quantization strategy #481
how are we expecting users to use QuantizationStrategy.ATTN_HEAD in a recipe? If i'm understanding correctly, it would look something like this?
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      config_groups:
        group0:
          targets: ["re:.*self_attn$"]
          weights:
            strategy: attn_head
            ...
        group1:
          targets: ["re:.*(q|k|v)_proj$"]
          weights:
            strategy: group
            ...

@brian-dellabetta I’ve decided that giving per-attention-head quantization its own strategy (rather than reusing group) makes more sense.
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      config_groups:
        group0:
          targets: ["re:.*self_attn$"]
          input_activations:
            strategy: attn_head
            ...
overall format LGTM, but i'm struggling with understanding how we're arriving at some of these expected_shapes
Thanks for updating!
The base branch was changed.
Signed-off-by: Kyle Sayers <[email protected]>
Looks good! Do we need logic somewhere to reverse flatten_attention_for_quantization? It seems like it would be important to be sure that the unflattening process is implemented in parallel with the flattening function.
@fynnsu The inverse function would require extra metadata (for example, unflattening (batch_size * seq_len) requires knowing either batch_size or seq_len). Calibration only requires the forward function. Implementing the inverse function would allow us to share the util across calibration and the quantization forward pass. This might be nice for standardization and potentially faster runtime, but isn't high priority right now.
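To make the metadata point concrete, here is a minimal NumPy sketch. It is not the PR's flatten_attention_for_quantization; the helper names `flatten_heads` and `unflatten_heads` and the flattened layout (batch_size * seq_len, num_heads, head_dim) are assumptions chosen for illustration. The forward direction needs only the tensor, while the inverse cannot recover batch_size from the flattened shape and has to be told.

```python
import numpy as np


def flatten_heads(state):
    """Hypothetical forward helper.

    (batch_size, num_heads, seq_len, head_dim)
        -> (batch_size * seq_len, num_heads, head_dim)
    """
    batch_size, num_heads, seq_len, head_dim = state.shape
    # Move seq_len next to batch_size, then merge the two leading dims.
    return state.transpose(0, 2, 1, 3).reshape(
        batch_size * seq_len, num_heads, head_dim
    )


def unflatten_heads(flat, batch_size):
    """Hypothetical inverse helper.

    batch_size must be passed in: flat.shape only exposes the product
    batch_size * seq_len, not the individual factors.
    """
    tokens, num_heads, head_dim = flat.shape
    if tokens % batch_size != 0:
        raise ValueError("batch_size does not divide the flattened dimension")
    seq_len = tokens // batch_size
    return flat.reshape(batch_size, seq_len, num_heads, head_dim).transpose(
        0, 2, 1, 3
    )


state = np.arange(2 * 4 * 3 * 8, dtype=np.float32).reshape(2, 4, 3, 8)
flat = flatten_heads(state)
restored = unflatten_heads(flat, batch_size=2)

# A different batch_size also yields a well-formed tensor, just the wrong one,
# which is why the inverse cannot guess the split on its own.
ambiguous = unflatten_heads(flat, batch_size=3)
```

With the correct batch_size the round trip reproduces the input exactly; with batch_size=3 the call still succeeds but returns a tensor of a different shape.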
Purpose
Given an attention state of shape (batch_size, num_heads, seq_len, head_dim), the attention head strategy will generate scales of shape (num_heads, 1, 1).
Prerequisites
Changes
Testing
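A minimal sketch of the shape relationship stated under Purpose, not the PR's actual tests. It assumes symmetric int8 quantization with an absolute-max observer, uses NumPy in place of torch tensors, and the helper name `attn_head_scales` is hypothetical. One scale is computed per head by reducing over the batch, sequence, and head_dim axes, and the trailing singleton dims let the result broadcast against the 4-D attention state.

```python
import numpy as np


def attn_head_scales(state, num_bits=8):
    """Hypothetical helper: one symmetric scale per attention head.

    state:   (batch_size, num_heads, seq_len, head_dim)
    returns: (num_heads, 1, 1)
    """
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    # Reduce over every axis except num_heads.
    absmax = np.abs(state).max(axis=(0, 2, 3))
    # Guard against an all-zero head to avoid dividing by zero later.
    absmax = np.maximum(absmax, 1e-12)
    return (absmax / qmax).reshape(-1, 1, 1)


rng = np.random.default_rng(0)
state = rng.standard_normal((2, 4, 3, 8)).astype(np.float32)

scales = attn_head_scales(state)

# (num_heads, 1, 1) broadcasts against (batch_size, num_heads, seq_len, head_dim).
quantized = np.clip(np.round(state / scales), -127, 127)
dequantized = quantized * scales
```

Because each head has its own scale, the round-trip error for any element is bounded by half of that head's scale.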