context
Today, torchao.float8 has a separate API from the rest of torchao. This is for historical reasons:
- float8 inference originally started in torchao.float8 but recently moved to quantize_ to better align with the other inference APIs
- there are some requirements which are important for float8 today and are not yet easy to meet in the other torchao APIs, such as: persistent state (delayed scaling), distributed integrations, and extensibility to graphs larger than the ops surrounding a linear layer
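For concreteness, here is a minimal sketch of how the two separate entry points look today. The names below (convert_to_float8_training, float8_dynamic_activation_float8_weight) are my reading of the current torchao APIs and are best-effort assumptions, not authoritative:
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
# training path: torchao.float8 has its own conversion entry point
train_model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
convert_to_float8_training(train_model)
# inference path: float8 now goes through quantize_, like the other torchao inference workflows
infer_model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
quantize_(infer_model, float8_dynamic_activation_float8_weight())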
next steps
We need to figure out the requirements for training, inference, and known future use cases for both. Then, we should align on how to best structure the torchao APIs to meet these requirements. Stay tuned, we will get this going after PTC 2024.
I think we could separate the implementation for the model prepared for training from the model used for inference, something like the following:
model = ...
# can be implemented with hooks, module swaps, etc.; this is more friendly for training
prepare_for_training_(model, ...)
# training
# use a tensor subclass for everything, for a better serialization/deserialization UX
convert_to_inference_(model)
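To make the proposal above a bit more concrete, here is a minimal sketch of how prepare_for_training_ and convert_to_inference_ could be implemented with a module swap; everything here (including Float8TrainingLinear) is hypothetical and only stands in for real float8 logic:
import torch.nn as nn

class Float8TrainingLinear(nn.Linear):
    # placeholder for a linear that would cast inputs/weights to float8 in forward
    @classmethod
    def from_float(cls, mod):
        new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None)
        new_mod.weight = mod.weight
        if mod.bias is not None:
            new_mod.bias = mod.bias
        return new_mod

def prepare_for_training_(model):
    # module swap: replace nn.Linear children with the training variant, in place
    for name, child in model.named_children():
        if isinstance(child, nn.Linear) and not isinstance(child, Float8TrainingLinear):
            setattr(model, name, Float8TrainingLinear.from_float(child))
        else:
            prepare_for_training_(child)

def convert_to_inference_(model):
    # after training, replace weights with a (hypothetical) float8 tensor subclass
    # so the checkpoint no longer depends on custom module classes
    for mod in model.modules():
        if isinstance(mod, Float8TrainingLinear):
            mod.weight = nn.Parameter(mod.weight.detach())  # stand-in for the subclass wrap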