-
Notifications
You must be signed in to change notification settings - Fork 81
Throw exception in ComputeAtLogicalDomainMapBuilder::handle by default #4282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jacobhinkle
wants to merge
4
commits into
main
Choose a base branch
from
jh/no_default_handle_domainmap_builder
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
d3b1ebb
Throw exception in ComputeAtLogicalDomainMapBuilder::handle by default
jacobhinkle fc7dbae
Fix error message
jacobhinkle 3bd5f4e
Add handlers for FullOp, IotaOp, EyeOp, GetItem
jacobhinkle e33bfb8
Add trivial override for ScatterOp
jacobhinkle File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| */ | ||
| // clang-format on | ||
| #include <debug.h> | ||
| #include <ir/internal_nodes.h> | ||
| #include <ir/iostream.h> | ||
| #include <ir/utils.h> | ||
| #include <iter_visitor.h> | ||
|
|
@@ -721,6 +722,187 @@ std::string UnmappableReductionDomains::toString() const { | |
| return ss.str(); | ||
| } | ||
|
|
||
| class BackwardVisitorNoDefaultHandlers : public BackwardVisitor { | ||
| public: | ||
| BackwardVisitorNoDefaultHandlers(bool map_through_reduction) | ||
| : BackwardVisitor(map_through_reduction) {} | ||
|
|
||
| protected: | ||
| using BackwardVisitor::handle; | ||
|
|
||
| #define M(e) \ | ||
| void handle(e* uop) override { \ | ||
| NVF_THROW("Unhandled expression type: " #e); \ | ||
| } | ||
| DISPATCH_FOR_ALL_EXPRS(M); | ||
| #undef M | ||
| }; | ||
|
|
||
| // Create a DisjointSets of logical IterDomains by traversing the | ||
| // current fusion entirely. IterDomains that can be mapped each | ||
| // other with computeAt are grouped into the same subset in the | ||
| // DisjointSets. | ||
| class ComputeAtLogicalDomainMapBuilder | ||
| : private BackwardVisitorNoDefaultHandlers { | ||
| public: | ||
| explicit ComputeAtLogicalDomainMapBuilder( | ||
| ComputeAtLogicalDomainMap& logical_map, | ||
| bool map_through_reduction = false); | ||
|
|
||
| private: | ||
| //! Initialize the bcast map for fusion outputs | ||
| void initializeBcastMap(const TensorView* tv, const IterDomain* id); | ||
|
|
||
| //! Set a pair of producer-consumer domain keys as mappable | ||
| void setMapped(const DomainKey& producer, const DomainKey& consumer); | ||
|
|
||
| //! Records two domains are invalid to map | ||
| void setInvalid(const DomainKey& key1, const DomainKey& key2); | ||
|
|
||
| //! Check if no pair of domains is invalid to map | ||
| bool isInvalid(const DomainKeySet& domains) const; | ||
|
|
||
| //! Track a pair of producer-consumer domains as potentially mappable. Inserts | ||
| //! entries into pending_map_, but does not add anything into the logical_map_ | ||
| //! (added when handle is called on a TensorView). Maybe mapped will, however, | ||
| //! immediately propagate broadcast iter domains. | ||
| void setMaybeMapped( | ||
| const TensorDomain* producer_td, | ||
| const IterDomain* producer_id, | ||
| const TensorDomain* consumer_td, | ||
| const IterDomain* consumer_id); | ||
|
|
||
| void addToPendingList(const DomainKey& producer, const DomainKey& consumer); | ||
|
|
||
| //! Map pointwise IterDomains from inputs of expressions to outputs. | ||
| //! Do not map reduction IterDomains in inputs. | ||
| void mapPointwiseLikeOp(Expr* e); | ||
|
|
||
| using BackwardVisitorNoDefaultHandlers::handle; | ||
|
|
||
| void dispatch(Expr* e) override; | ||
|
|
||
| void handle(UnaryOp* uop) override { | ||
| mapPointwiseLikeOp(uop); | ||
| } | ||
|
|
||
| void handle(BinaryOp* bop) override { | ||
| mapPointwiseLikeOp(bop); | ||
| } | ||
|
|
||
| void handle(TernaryOp* top) override { | ||
| mapPointwiseLikeOp(top); | ||
| } | ||
|
|
||
| void handle(RNGOp* top) override; | ||
|
|
||
| void handle(SelectOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(IndexSelectOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(GatherOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(ReductionOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(GroupedReductionOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(WelfordOp* wop) override { | ||
| mapPointwiseLikeOp(wop); | ||
| } | ||
|
|
||
| void handle(LoadStoreOp* ldst) override { | ||
| mapPointwiseLikeOp(ldst); | ||
| } | ||
|
|
||
| void handle(MmaOp* wop) override { | ||
| mapPointwiseLikeOp(wop); | ||
| } | ||
|
|
||
| void handle(ViewOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(ViewAsScalar* op) override; | ||
|
|
||
| void handle(BroadcastOp* op) override; | ||
|
|
||
| void handle(SqueezeOp* op) override; | ||
|
|
||
| void handle(ExpandOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(RepeatOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(PadOp* op) override { | ||
| // For compute-at, padded id should be mapped | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(SliceOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(CatOp* op) override { | ||
| // For compute-at, concat id should be mapped | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(EmbeddingFwdOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(FullOp* op) override {} | ||
|
|
||
| void handle(GetItem* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(IotaOp* op) override {} | ||
|
|
||
| void handle(EyeOp* op) override {} | ||
|
Comment on lines
+867
to
+875
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added these handlers |
||
|
|
||
| void handle(TensorView* tv) override; | ||
|
|
||
| //! Maps all pending mappings. | ||
| //! This is called for each of TensorViews in a backward traversal, | ||
| //! recursively building mappings from the output tensors to the | ||
| //! input tensors. | ||
| void mapAllPendingMappings(const DomainKey& key); | ||
|
|
||
| //! Maps all pending mappings for id of td. When id is a broadcast, | ||
| //! mapping is done separately for each concrete domain. | ||
| void mapAllPendingMappings(const TensorDomain* td, IterDomain* id); | ||
|
|
||
| bool safeToMap(const DomainKeySet& domains); | ||
|
|
||
| private: | ||
| ComputeAtLogicalDomainMap& logical_map_; | ||
| //! Keep track of what we want to try and map | ||
| DomainKeyMap<DomainKeySet> pending_map_; | ||
| std::unordered_set<Expr*> visited_; | ||
| //! Helper class to find invalid mappings due to reductions | ||
| UnmappableReductionDomains incompatible_domains_; | ||
| //! Running vector of domain pairs that are invalid to map | ||
| std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_; | ||
|
|
||
| //! Disable UnmappableReductions check, should | ||
| //! always be false for compute_at use cases | ||
| bool map_through_reduction_ = false; | ||
| }; | ||
|
|
||
| void ComputeAtLogicalDomainMap::build(bool map_through_reduction) { | ||
| // Make sure we start from scratch. Throw away previous results. | ||
| eq_set_.clear(); | ||
|
|
@@ -1005,7 +1187,7 @@ std::string ComputeAtLogicalDomainMap::toString() const { | |
| ComputeAtLogicalDomainMapBuilder::ComputeAtLogicalDomainMapBuilder( | ||
| ComputeAtLogicalDomainMap& logical_map, | ||
| bool map_through_reduction) | ||
| : BackwardVisitor(false), | ||
| : BackwardVisitorNoDefaultHandlers(false), | ||
| logical_map_(logical_map), | ||
| map_through_reduction_(map_through_reduction) { | ||
| Fusion* fusion = FusionGuard::getCurFusion(); | ||
|
|
@@ -1174,7 +1356,7 @@ void ComputeAtLogicalDomainMapBuilder::dispatch(Expr* e) { | |
| if (visited_.find(e) != visited_.end()) { | ||
| return; | ||
| } | ||
| BackwardVisitor::dispatch(e); | ||
| BackwardVisitorNoDefaultHandlers::dispatch(e); | ||
| visited_.insert(e); | ||
| } | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This did not need to be in the header so I moved it to the .cpp. It is unchanged other than deriving from
BackwardVisitorNoDefaultHandlersand implementing a few missing handlers.