This repository was archived by the owner on Oct 11, 2024. It is now read-only.

[1/N] Rs/vllm quantization - Refactor to minimize llama.py changes#186

Merged
varun-sundar-rabindranath merged 12 commits into vllm-quantization from rs/vllm-quantization
Apr 16, 2024

Conversation


@robertgshaw2-redhat robertgshaw2-redhat commented Apr 13, 2024

Paired with @dsikka to refactor SmoothQuantLinearMethod to avoid making changes to llama.py

  • Removed all the "layer specific" SmoothQuantLinearMethod variants by making the shard indexing generic (splitting QKV into logical shards) and explicitly handling the state_dict conversion
  • Successfully whittled the change down to adding only one LOC to llama.py

Many todos left, including:

  • We currently hardcode use_per_token; we need to read this from the quant config instead
  • We need a way to pass different quant configs to each layer to support nonuniform quantization
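The generic shard indexing mentioned above can be sketched as follows. This is a hypothetical illustration, not the PR's actual code; `shard_offsets` and `slice_shard` are invented helper names:

```python
# Sketch: splitting a fused QKV weight into logical shards using a list
# of logical widths, so per-shard quantization params can be indexed
# generically instead of with layer-specific logic.
from typing import List


def shard_offsets(logical_widths: List[int]) -> List[int]:
    # Cumulative start offset of each logical shard along the output dim.
    offsets = [0]
    for width in logical_widths:
        offsets.append(offsets[-1] + width)
    return offsets


def slice_shard(fused_rows: list, logical_widths: List[int], idx: int) -> list:
    # Return the rows belonging to logical shard `idx` of a fused tensor.
    offs = shard_offsets(logical_widths)
    return fused_rows[offs[idx]:offs[idx + 1]]
```

For a fused QKV tensor with logical widths `[4, 2, 2]`, shard 1 (the K shard) spans rows 4 to 6.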

@robertgshaw2-redhat robertgshaw2-redhat changed the title Rs/vllm quantization [1/N] Rs/vllm quantization Apr 14, 2024
@robertgshaw2-redhat robertgshaw2-redhat changed the title [1/N] Rs/vllm quantization [1/N] Rs/vllm quantization - Refactor To Remove SQLinearMethod Variants Apr 14, 2024
@robertgshaw2-redhat robertgshaw2-redhat changed the title [1/N] Rs/vllm quantization - Refactor To Remove SQLinearMethod Variants [1/N] Rs/vllm quantization - Refactor to minimize llama.py changes Apr 14, 2024
    output_size_per_partition: int, input_size: int,
    output_size: int,
-   params_dtype: torch.dtype) -> Dict[str, Any]:
+   params_dtype: torch.dtype, logical_widths: Optional[List[int]]) -> Dict[str, Any]:


lift this to be inside LinearMethodBase ?

Collaborator Author


I got rid of this in the next PR

@@ -1,8 +1,9 @@
from typing import Any, Dict, List, Tuple, Optional


very nice cleanup 🙌

if _is_support_smoothquant(model_config):
-    model = model_class(model_config.hf_config, linear_method,
-                        quant_config)
+    model = model_class(model_config.hf_config, linear_method)
@varun-sundar-rabindranath varun-sundar-rabindranath Apr 16, 2024


How come we don't have to pass in the quant_config? Because the LinearMethod already knows if it is quantized?

Collaborator Author


Yeah linear method handles it

@varun-sundar-rabindranath

LGTM.

… via config (#188)

Refactored to support nonuniform quantization by adding a new layer of
abstraction.

Now, `SmoothQuantLinearMethod` can hold a `SmoothQuantFormat`, which
implements the details of how to do quant and dequant operations. There
are two `SmoothQuantFormat` classes:
- `SmoothQuantDynamicPerToken`
- `SmoothQuantStaticPerTensor`
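As a rough illustration of what these two formats' names imply, a simplified int8 sketch follows. The real logic lives in CUDA kernels; the function names here are invented for illustration:

```python
# Illustrative sketch (not vLLM's kernels) of the two scale strategies.
def static_per_tensor_scale(precomputed_scale):
    # Static per-tensor: one scale for the whole tensor, fixed at load
    # time from the quantization config.
    return precomputed_scale


def dynamic_per_token_scales(activations):
    # Dynamic per-token: one scale per row (token), computed at runtime
    # from that row's absolute maximum.
    int8_max = 127.0
    return [max(abs(v) for v in row) / int8_max for row in activations]
```

Dynamic per-token trades a runtime reduction over each row for tighter quantization ranges; static per-tensor avoids the runtime cost but uses one range for everything.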

We have the following lifecycle:
- `LinearMethod` is created during `get_model`, has access to
`QuantizationConfig`
- `Layer` is initialized and passed a `LinearMethod`
- `Layer` calls `LinearMethod.create_weights`, which creates a
dictionary of weights and metadata
- `Layer` calls `LinearMethod.apply_weights` during inference, passing
the dictionary created during `create_weights`
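The lifecycle above can be sketched with stand-in classes. These are hypothetical minimal names, with a plain-Python matmul in place of the real kernels:

```python
from typing import Any, Dict


class LinearMethod:
    def create_weights(self, layer_name: str) -> Dict[str, Any]:
        # Created once per layer; returns weights plus metadata.
        return {"weight": [[1.0, 0.0], [0.0, 1.0]], "layer_name": layer_name}

    def apply_weights(self, weights: Dict[str, Any], x):
        # Called during inference with the dict from create_weights.
        w = weights["weight"]
        return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)]
                for row in x]


class Layer:
    def __init__(self, name: str, linear_method: LinearMethod):
        # Layer is initialized with a LinearMethod and immediately asks
        # it to create the weight dictionary.
        self.linear_method = linear_method
        self.weights = linear_method.create_weights(name)

    def forward(self, x):
        return self.linear_method.apply_weights(self.weights, x)
```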

This PR modifies the `LinearMethod.create_weights` API to receive a
`layer_name` argument. The `LinearMethod` then looks in the `config`
to determine which `SmoothQuantFormat` to use for the layer with that
`layer_name`.
- As a result, the `LinearMethod` is responsible for parsing the config
from disk and making decisions about what the inference format should
look like. In this specific case, since the `SmoothQuantConfig` is not
very good, we just match on the suffix `qkv` to determine what each
layer should use --> but for SparseMLConfig, we could use a similar
structure
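The suffix match described above can be sketched as follows; which format maps to which suffix is illustrative here, not taken from the actual config:

```python
# Hypothetical sketch of picking a format per layer by layer_name suffix.
def get_layer_format(layer_name: str) -> str:
    # Names like "model.layers.0.self_attn.qkv" get one format,
    # everything else gets the other (mapping chosen for illustration).
    if layer_name.endswith("qkv"):
        return "SmoothQuantStaticPerTensor"
    return "SmoothQuantDynamicPerToken"
```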

In this PR, the `SmoothQuantFormat` is passed in the dictionary returned
by `create_weights` and then used by `apply_weights`.


### In Summary

I think this is a good overall structure because it:
- (a) allows us to make minimal changes to the existing models
- (b) allows us to make no changes to the model loading lifecycle (i.e.
config / constructor / linear method); critically, this requires having one
LinearMethod that propagates through the whole model
- (c) encapsulates the nonuniform logic in the `LinearMethod`, giving us a
clean interface into the kernels

### For SparseML Models

We could imagine the following architecture:

#### Config
Config is responsible for:
- loading config from disk
- mapping layer_names --> `SparseMLFormat`

```python
class SparseMLConfig:
    @classmethod
    def from_dict(cls, config_dict):
        ...

    def get_layer_format(self, layer_name):
        return SparseMLFormat
```

#### LinearMethod
The LinearMethod is responsible for:
- the interface between layers and kernels (so the LinearMethod is what the
model uses)

```python
class SparseMLLinearMethod:
    def __init__(self, sparseml_config):
        self.sparseml_config = sparseml_config

    def create_weights(self, layer_name, **kwargs):
        # This, e.g., is where nonuniform quantization is supported:
        # each layer can get its own format.
        fmt = self.sparseml_config.get_layer_format(layer_name)

        weights = fmt.get_weights(**kwargs)
        weights["format"] = fmt
        return weights

    # Wrapper around the SparseMLFormat
    def apply_weights(self, x, weights):
        fmt = weights["format"]
        return fmt.apply_weights(weights["weights"], x)
```

#### SparseMLFormat
Format is responsible for:
- actual weight creation and forward

```python
class SparseMLFormat:
    def get_weights(self, sizes):
        # Returns a dictionary, e.g.:
        return {
            "weights": x,
            "scales": y,
        }

    def apply_weights(self, weights, x):
        # Calls the CUDA kernel
        return output
```

Sample formats:
- `W8A8DynamicPerToken`
- `SparseW8A8StaticPerTensorAsymmetric`
- `W4A8DynamicPerToken`
- ...
@varun-sundar-rabindranath varun-sundar-rabindranath merged commit 3e7d1c8 into vllm-quantization Apr 16, 2024
@varun-sundar-rabindranath varun-sundar-rabindranath deleted the rs/vllm-quantization branch April 16, 2024 17:43