Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
57f7874 to 9607be8
SunMarc
left a comment
A few nits but overall fine!
```python
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
**kwargs,
```
Most of the args should be optional kwargs, so that we can clean up the other convert functions with **kwargs and only declare the args that are actually used. That's fine for now; we should do it later on.
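A hedged sketch of what that cleanup could look like: most arguments become optional keyword-only parameters, and unused ones are swallowed by **kwargs. The names mirror the snippet above, but the function itself is hypothetical, not the real API.

```python
# Hypothetical sketch only: most args become optional keyword-only, so each
# convert function declares just what it uses and forwards the rest.
def convert(source_keys, target_keys, *, full_layer_name=None, model=None,
            missing_keys=None, config=None, **kwargs):
    # A converter that only needs `config` can ignore every other kwarg.
    used = {"config": config} if config is not None else {}
    return {"source_keys": source_keys, "target_keys": target_keys, **used}

result = convert(["weight:_data"], ["weight"], config={"bits": 8}, extra="ignored")
```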
```python
if self.pre_quantized:
    return False
```
This is typically one of the cases where we get the wrong numel calculation, since we are skipping them. We should try to fix that at some point; it should be quite simple.
```python
WeightConverter(
    source_keys=["weight:_data"],
    target_keys="weight",
    operations=[TorchAoDeserialize(self)],
```
A WeightRename should be enough in this case, no?
Yes, both are fine I guess.
FYI, we just changed `weight:_data` to `weight_qdata` so these things can be attached to the module directly, in case we need it in the future. pytorch/ao@ba3ac9f
WeightConverter is better than WeightRename here because there is an op!
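As a toy illustration of the distinction in this thread (class names echo the discussion but are stand-ins, not the actual transformers API): a pure rename leaves the tensor untouched, while a converter also runs its operations on the value.

```python
# Stand-in classes for illustration only; not the real transformers API.
class WeightRename:
    def __init__(self, src, dst):
        self.src, self.dst = src, dst

    def apply(self, value):
        return value  # only the key changes; the tensor is untouched


class WeightConverter(WeightRename):
    def __init__(self, src, dst, operations=()):
        super().__init__(src, dst)
        self.operations = list(operations)

    def apply(self, value):
        for op in self.operations:  # e.g. a deserialize step
            value = op(value)
        return value


deserialize = lambda raw: {"deserialized": raw}  # hypothetical op
conv = WeightConverter("weight:_data", "weight", operations=[deserialize])
```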
```python
full_layer_name: str | None = None,
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
```
Safe serialization doesn't work yet because of torchao, so it is fine to just clean up a bit; we can come back to that later on.
```python
# print("metadata", self.hf_quantizer.metadata)
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
```
In a follow-up PR, we can modify this to work with all tensor subclasses and with sharded checkpoint files.
I'm thinking that in this convert function, we load the tensor subclass components (i.e. _weight_qdata) as module parameters. After all files are loaded, we can delete them and replace the actual layer weights with the reconstructed quantized tensors.
See #41998 for details. Will this approach still work with the new refactoring? cc @jerryzh168
@liangel-02 Yeah, I think our original approach should still work. I guess it's fine to land this PR first, and you can re-open #41998 on top of these new changes, since you are more familiar with this part.
Thanks both for chiming in! 🤗
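The two-phase idea sketched in this thread (first collect the flattened tensor-subclass components from the state dict, then reconstruct one quantized weight per layer once every file has been read) could look roughly like this. Plain dicts stand in for tensors, and the `layer.weight:component` key scheme is an assumption for illustration.

```python
# Illustrative only: plain dicts stand in for tensors/subclasses.
def collect_components(state_dict):
    """Group flat 'layer.weight:component' keys by layer."""
    layers = {}
    for key, value in state_dict.items():
        base, _, component = key.rpartition(":")
        layers.setdefault(base, {})[component] = value
    return layers

def reconstruct(components):
    """Stand-in for rebuilding the quantized tensor subclass from its parts."""
    return {"qdata": components["qdata"], "scale": components["scale"]}

flat = {"layer.weight:qdata": [1, 2], "layer.weight:scale": 0.5}
weights = {name: reconstruct(parts) for name, parts in collect_components(flat).items()}
```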
```python
if self.pre_quantized:
    return [
        WeightConverter(
            source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
```
Nit: maybe also add [weight_qdata, weight_scale] as well, since zero_point may be optional, like https://github.com/pytorch/ao/blob/2ff1eb2e356275cfbe46832327387d382c72945d/torchao/quantization/quantize_/workflows/float8/float8_tensor.py#L99
Let's do that in a follow-up PR, since safetensors support is broken with the latest torchao version.
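When that follow-up happens, one simple way to make zero_point optional is to register both key patterns. The WeightConverter below is a minimal stand-in for the real class, kept only to show the shape of the idea.

```python
# Minimal stand-in for the real WeightConverter, for illustration only.
class WeightConverter:
    def __init__(self, source_keys, target_keys):
        self.source_keys = source_keys
        self.target_keys = target_keys

def get_conversions():
    return [
        # full pattern, with zero_point
        WeightConverter(["weight:qdata", "weight:scale", "weight:zero_point"], "weight"),
        # fallback pattern: zero_point may be absent from the checkpoint
        WeightConverter(["weight:qdata", "weight:scale"], "weight"),
    ]
```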
ArthurZucker
left a comment
Great work! Thanks 🤗
```python
if hf_quantizer is not None:
    weight_conversions.extend(hf_quantizer.get_weight_conversions())
```
[For maintainers] Suggested jobs to run (before merge): run-slow: finegrained_fp8, torchao_integration
* inital commit
* up
* update unexpected later on
* fix
* update
* simplify our lives
* isolate a bit more
* fixup
* small nits
* style
* nit
* fix common cases
* fix post merge
* bnb needs missing keys
* small fix
* bettrer documentation
* no veradict + base class
* rake review comments
* take all comments
* fix super init
* update doc to be more real
* up
* fix some tests
* weight convertor
* up
* mostly correct
* oups
* skip non linears
* only some tests to go
* need quantization
* fix tests
* rm comment
* revert
* revert 2
* style
* up
* up
* remove unsafe loading path
* fix
* fix
* fix
* up
* rm Dtensor import
* rm replicate import
* fix imports
* up
* minor modifications
* add quantizaton_operation
* delattr
* fix
* fix get_param_buffer
* better to just set module initialized
* rm tie_weights
* guard imports
* up
* rm offloading for now
* add license
* don't guard torch
* comment
* fix
* rm torch.grad
* revert
* fix
* add guard
* add second guard

---------

Co-authored-by: Arthur <arthur.zucker@gmail.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
What does this PR do?

Refactors the torchao quantization method to use conversion ops instead of the classical `create_quantized_param`.
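A minimal sketch of the shift this PR describes: instead of an imperative per-parameter hook, the quantizer contributes declarative conversion ops that the loader applies uniformly to the state dict. All names here are illustrative, not the actual implementation.

```python
# Illustrative only: a declarative op the loader can apply uniformly.
class RenameOp:
    def __init__(self, src, dst):
        self.src, self.dst = src, dst

    def apply(self, state_dict):
        return {(self.dst if k == self.src else k): v for k, v in state_dict.items()}

ops = [RenameOp("weight:_data", "weight")]
state_dict = {"weight:_data": [1, 2, 3]}
for op in ops:
    state_dict = op.apply(state_dict)
```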