-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add decorator for custom op and inductor decomp registration #434
Merged
+173
−22
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
"skip_if_compute_capability_less_than", | ||
"benchmark_torch_function_in_microseconds", | ||
"find_multiple", | ||
"_register_custom_op", | ||
"get_model_size_in_bytes", | ||
"unwrap_tensor_subclass", | ||
"TORCH_VERSION_AFTER_2_2", | ||
|
@@ -65,7 +66,7 @@ def wrapper(*args, **kwargs): | |
|
||
def benchmark_torch_function_in_microseconds(f, *args, **kwargs): | ||
import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded | ||
|
||
# Manual warmup | ||
f(*args, **kwargs) | ||
f(*args, **kwargs) | ||
|
@@ -84,6 +85,55 @@ def find_multiple(n: int, *args: Tuple[int]) -> int: | |
return n | ||
return n + k - (n % k) | ||
|
||
def _register_custom_op(lib): | ||
"""This decorator is used to preserve some high level operators for torch.export.export | ||
while still allow them to be decomposed for inductor path | ||
|
||
requirement: make sure `fn.__name__[1:]` is the operator name you want to register | ||
|
||
NOTE: This should be applied at the top, after all other decorators have been applied | ||
NOTE: We haven't tested the case when `fn` accepts tensor subclass instance as input, | ||
e.g. uint4 tensor subclass instance, and we'll probably need to figure out what would make | ||
sense for downstream system (like executorch) to accept as well | ||
|
||
Example: | ||
lib = torch.library.Library("my_namespace', "FRAGMENT") | ||
|
||
register_custom_op = _register_custom_op(lib) | ||
|
||
@register_custom_op | ||
def _the_op_that_needs_to_be_preserved(...) | ||
... | ||
|
||
# after this, `_the_op_that_needs_to_be_preserved` will be preserved as | ||
# torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after | ||
# torch.export.export / torch._export.capture_pre_autograd_graph | ||
|
||
""" | ||
from torch._inductor.decomposition import register_decomposition | ||
|
||
def decorator(fn): | ||
if TORCH_VERSION_AFTER_2_5: | ||
from torch._library.infer_schema import infer_schema | ||
|
||
# expecting fn.__name__ starts with `_` and we want to take the rest | ||
# to be the name of the custom op | ||
assert fn.__name__[0] == "_", f"Expecting function name starts with `_`, got {fn.__name__}" | ||
assert not any(c in fn.__name__ for c in ".<>"), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}" | ||
op_name = fn.__name__[1:] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you assert there is no "." or "<" or ">" in fn.name? this can happen with lambdas or local functions |
||
schema = op_name + infer_schema(fn) | ||
lib.define(schema) | ||
lib.impl(op_name, fn, "CompositeImplicitAutograd") | ||
|
||
lib_namespace = lib.ns | ||
op = getattr(getattr(torch.ops, lib_namespace), op_name) | ||
register_decomposition([op])(fn) | ||
return op | ||
else: | ||
return fn | ||
|
||
return decorator | ||
|
||
def get_model_size_in_bytes(model, ignore_embeddings=False): | ||
""" | ||
Returns the model size in bytes. The option to ignore embeddings | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this checking?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is because right now we will only see these ops for int8da_int8w quantization, other types of quant (e.g. int4 weight only) will call into the efficient kernels directly
we should probably figure out a path for executorch, I think we could abstract this with "layout", what would be a good name here?