-
Notifications
You must be signed in to change notification settings - Fork 74
Throw exception in ComputeAtLogicalDomainMapBuilder::handle by default #4282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d3b1ebb
fc7dbae
3bd5f4e
e33bfb8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| */ | ||
| // clang-format on | ||
| #include <debug.h> | ||
| #include <ir/internal_nodes.h> | ||
| #include <ir/iostream.h> | ||
| #include <ir/utils.h> | ||
| #include <iter_visitor.h> | ||
|
|
@@ -721,6 +722,192 @@ std::string UnmappableReductionDomains::toString() const { | |
| return ss.str(); | ||
| } | ||
|
|
||
| class BackwardVisitorNoDefaultHandlers : public BackwardVisitor { | ||
| public: | ||
| BackwardVisitorNoDefaultHandlers(bool map_through_reduction) | ||
| : BackwardVisitor(map_through_reduction) {} | ||
|
|
||
| protected: | ||
| using BackwardVisitor::handle; | ||
|
|
||
| #define M(e) \ | ||
| void handle(e* uop) override { \ | ||
| NVF_THROW("Unhandled expression type: " #e); \ | ||
| } | ||
| DISPATCH_FOR_ALL_EXPRS(M); | ||
| #undef M | ||
| }; | ||
|
|
||
| // Create a DisjointSets of logical IterDomains by traversing the | ||
| // current fusion entirely. IterDomains that can be mapped each | ||
| // other with computeAt are grouped into the same subset in the | ||
| // DisjointSets. | ||
| class ComputeAtLogicalDomainMapBuilder | ||
| : private BackwardVisitorNoDefaultHandlers { | ||
| public: | ||
| explicit ComputeAtLogicalDomainMapBuilder( | ||
| ComputeAtLogicalDomainMap& logical_map, | ||
| bool map_through_reduction = false); | ||
|
|
||
| private: | ||
| //! Initialize the bcast map for fusion outputs | ||
| void initializeBcastMap(const TensorView* tv, const IterDomain* id); | ||
|
|
||
| //! Set a pair of producer-consumer domain keys as mappable | ||
| void setMapped(const DomainKey& producer, const DomainKey& consumer); | ||
|
|
||
| //! Records two domains are invalid to map | ||
| void setInvalid(const DomainKey& key1, const DomainKey& key2); | ||
|
|
||
| //! Check if no pair of domains is invalid to map | ||
| bool isInvalid(const DomainKeySet& domains) const; | ||
|
|
||
| //! Track a pair of producer-consumer domains as potentially mappable. Inserts | ||
| //! entries into pending_map_, but does not add anything into the logical_map_ | ||
| //! (added when handle is called on a TensorView). Maybe mapped will, however, | ||
| //! immediately propagate broadcast iter domains. | ||
| void setMaybeMapped( | ||
| const TensorDomain* producer_td, | ||
| const IterDomain* producer_id, | ||
| const TensorDomain* consumer_td, | ||
| const IterDomain* consumer_id); | ||
|
|
||
| void addToPendingList(const DomainKey& producer, const DomainKey& consumer); | ||
|
|
||
| //! Map pointwise IterDomains from inputs of expressions to outputs. | ||
| //! Do not map reduction IterDomains in inputs. | ||
| void mapPointwiseLikeOp(Expr* e); | ||
|
|
||
| using BackwardVisitorNoDefaultHandlers::handle; | ||
|
|
||
| void dispatch(Expr* e) override; | ||
|
|
||
| void handle(UnaryOp* uop) override { | ||
| mapPointwiseLikeOp(uop); | ||
| } | ||
|
|
||
| void handle(BinaryOp* bop) override { | ||
| mapPointwiseLikeOp(bop); | ||
| } | ||
|
|
||
| void handle(TernaryOp* top) override { | ||
| mapPointwiseLikeOp(top); | ||
| } | ||
|
|
||
| void handle(RNGOp* top) override; | ||
|
|
||
| void handle(SelectOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(IndexSelectOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(GatherOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(ReductionOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(GroupedReductionOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(WelfordOp* wop) override { | ||
| mapPointwiseLikeOp(wop); | ||
| } | ||
|
|
||
| void handle(LoadStoreOp* ldst) override { | ||
| mapPointwiseLikeOp(ldst); | ||
| } | ||
|
|
||
| void handle(MmaOp* wop) override { | ||
| mapPointwiseLikeOp(wop); | ||
| } | ||
|
|
||
| void handle(ViewOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(ViewAsScalar* op) override; | ||
|
|
||
| void handle(BroadcastOp* op) override; | ||
|
|
||
| void handle(SqueezeOp* op) override; | ||
|
|
||
| void handle(ExpandOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(RepeatOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(PadOp* op) override { | ||
| // For compute-at, padded id should be mapped | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(SliceOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(CatOp* op) override { | ||
| // For compute-at, concat id should be mapped | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(EmbeddingFwdOp* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(FullOp* op) override {} | ||
|
|
||
| void handle(GetItem* op) override { | ||
| mapPointwiseLikeOp(op); | ||
| } | ||
|
|
||
| void handle(IotaOp* op) override {} | ||
|
|
||
| void handle(EyeOp* op) override {} | ||
|
Comment on lines
+867
to
+875
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added these handlers |
||
|
|
||
| void handle(ScatterOp* op) override { | ||
| // TODO: I think we should map all dims like pointwise here other than | ||
| // op->dim() | ||
| } | ||
|
Comment on lines
+877
to
+880
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added this handler. Maybe it should be |
||
|
|
||
| void handle(TensorView* tv) override; | ||
|
|
||
| //! Maps all pending mappings. | ||
| //! This is called for each of TensorViews in a backward traversal, | ||
| //! recursively building mappings from the output tensors to the | ||
| //! input tensors. | ||
| void mapAllPendingMappings(const DomainKey& key); | ||
|
|
||
| //! Maps all pending mappings for id of td. When id is a broadcast, | ||
| //! mapping is done separately for each concrete domain. | ||
| void mapAllPendingMappings(const TensorDomain* td, IterDomain* id); | ||
|
|
||
| bool safeToMap(const DomainKeySet& domains); | ||
|
|
||
| private: | ||
| ComputeAtLogicalDomainMap& logical_map_; | ||
| //! Keep track of what we want to try and map | ||
| DomainKeyMap<DomainKeySet> pending_map_; | ||
| std::unordered_set<Expr*> visited_; | ||
| //! Helper class to find invalid mappings due to reductions | ||
| UnmappableReductionDomains incompatible_domains_; | ||
| //! Running vector of domain pairs that are invalid to map | ||
| std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_; | ||
|
|
||
| //! Disable UnmappableReductions check, should | ||
| //! always be false for compute_at use cases | ||
| bool map_through_reduction_ = false; | ||
| }; | ||
|
|
||
| void ComputeAtLogicalDomainMap::build(bool map_through_reduction) { | ||
| // Make sure we start from scratch. Throw away previous results. | ||
| eq_set_.clear(); | ||
|
|
@@ -1005,7 +1192,7 @@ std::string ComputeAtLogicalDomainMap::toString() const { | |
| ComputeAtLogicalDomainMapBuilder::ComputeAtLogicalDomainMapBuilder( | ||
| ComputeAtLogicalDomainMap& logical_map, | ||
| bool map_through_reduction) | ||
| : BackwardVisitor(false), | ||
| : BackwardVisitorNoDefaultHandlers(false), | ||
| logical_map_(logical_map), | ||
| map_through_reduction_(map_through_reduction) { | ||
| Fusion* fusion = FusionGuard::getCurFusion(); | ||
|
|
@@ -1174,7 +1361,7 @@ void ComputeAtLogicalDomainMapBuilder::dispatch(Expr* e) { | |
| if (visited_.find(e) != visited_.end()) { | ||
| return; | ||
| } | ||
| BackwardVisitor::dispatch(e); | ||
| BackwardVisitorNoDefaultHandlers::dispatch(e); | ||
| visited_.insert(e); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This did not need to be in the header so I moved it to the .cpp. It is unchanged other than deriving from
BackwardVisitorNoDefaultHandlersand implementing a few missing handlers.