Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 189 additions & 2 deletions csrc/logical_domain_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/
// clang-format on
#include <debug.h>
#include <ir/internal_nodes.h>
#include <ir/iostream.h>
#include <ir/utils.h>
#include <iter_visitor.h>
Expand Down Expand Up @@ -721,6 +722,192 @@ std::string UnmappableReductionDomains::toString() const {
return ss.str();
}

class BackwardVisitorNoDefaultHandlers : public BackwardVisitor {
public:
BackwardVisitorNoDefaultHandlers(bool map_through_reduction)
: BackwardVisitor(map_through_reduction) {}

protected:
using BackwardVisitor::handle;

#define M(e) \
void handle(e* uop) override { \
NVF_THROW("Unhandled expression type: " #e); \
}
DISPATCH_FOR_ALL_EXPRS(M);
#undef M
};

// Create a DisjointSets of logical IterDomains by traversing the
// current fusion entirely. IterDomains that can be mapped each
// other with computeAt are grouped into the same subset in the
// DisjointSets.
class ComputeAtLogicalDomainMapBuilder
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This did not need to be in the header so I moved it to the .cpp. It is unchanged other than deriving from BackwardVisitorNoDefaultHandlers and implementing a few missing handlers.

: private BackwardVisitorNoDefaultHandlers {
public:
explicit ComputeAtLogicalDomainMapBuilder(
ComputeAtLogicalDomainMap& logical_map,
bool map_through_reduction = false);

private:
//! Initialize the bcast map for fusion outputs
void initializeBcastMap(const TensorView* tv, const IterDomain* id);

//! Set a pair of producer-consumer domain keys as mappable
void setMapped(const DomainKey& producer, const DomainKey& consumer);

//! Records two domains are invalid to map
void setInvalid(const DomainKey& key1, const DomainKey& key2);

//! Check if no pair of domains is invalid to map
bool isInvalid(const DomainKeySet& domains) const;

//! Track a pair of producer-consumer domains as potentially mappable. Inserts
//! entries into pending_map_, but does not add anything into the logical_map_
//! (added when handle is called on a TensorView). Maybe mapped will, however,
//! immediately propagate broadcast iter domains.
void setMaybeMapped(
const TensorDomain* producer_td,
const IterDomain* producer_id,
const TensorDomain* consumer_td,
const IterDomain* consumer_id);

void addToPendingList(const DomainKey& producer, const DomainKey& consumer);

//! Map pointwise IterDomains from inputs of expressions to outputs.
//! Do not map reduction IterDomains in inputs.
void mapPointwiseLikeOp(Expr* e);

using BackwardVisitorNoDefaultHandlers::handle;

void dispatch(Expr* e) override;

void handle(UnaryOp* uop) override {
mapPointwiseLikeOp(uop);
}

void handle(BinaryOp* bop) override {
mapPointwiseLikeOp(bop);
}

void handle(TernaryOp* top) override {
mapPointwiseLikeOp(top);
}

void handle(RNGOp* top) override;

void handle(SelectOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(IndexSelectOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(GatherOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(ReductionOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(GroupedReductionOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(WelfordOp* wop) override {
mapPointwiseLikeOp(wop);
}

void handle(LoadStoreOp* ldst) override {
mapPointwiseLikeOp(ldst);
}

void handle(MmaOp* wop) override {
mapPointwiseLikeOp(wop);
}

void handle(ViewOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(ViewAsScalar* op) override;

void handle(BroadcastOp* op) override;

void handle(SqueezeOp* op) override;

void handle(ExpandOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(RepeatOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(PadOp* op) override {
// For compute-at, padded id should be mapped
mapPointwiseLikeOp(op);
}

void handle(SliceOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(CatOp* op) override {
// For compute-at, concat id should be mapped
mapPointwiseLikeOp(op);
}

void handle(EmbeddingFwdOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(FullOp* op) override {}

void handle(GetItem* op) override {
mapPointwiseLikeOp(op);
}

void handle(IotaOp* op) override {}

void handle(EyeOp* op) override {}
Comment on lines +867 to +875
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added these handlers


void handle(ScatterOp* op) override {
// TODO: I think we should map all dims like pointwise here other than
// op->dim()
}
Comment on lines +877 to +880
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this handler. Maybe it should be mapPointwiseOp(op) or similar instead, but since it was missing, this empty override matches current behavior.


void handle(TensorView* tv) override;

//! Maps all pending mappings.
//! This is called for each of TensorViews in a backward traversal,
//! recursively building mappings from the output tensors to the
//! input tensors.
void mapAllPendingMappings(const DomainKey& key);

//! Maps all pending mappings for id of td. When id is a broadcast,
//! mapping is done separately for each concrete domain.
void mapAllPendingMappings(const TensorDomain* td, IterDomain* id);

bool safeToMap(const DomainKeySet& domains);

private:
ComputeAtLogicalDomainMap& logical_map_;
//! Keep track of what we want to try and map
DomainKeyMap<DomainKeySet> pending_map_;
std::unordered_set<Expr*> visited_;
//! Helper class to find invalid mappings due to reductions
UnmappableReductionDomains incompatible_domains_;
//! Running vector of domain pairs that are invalid to map
std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_;

//! Disable UnmappableReductions check, should
//! always be false for compute_at use cases
bool map_through_reduction_ = false;
};

void ComputeAtLogicalDomainMap::build(bool map_through_reduction) {
// Make sure we start from scratch. Throw away previous results.
eq_set_.clear();
Expand Down Expand Up @@ -1005,7 +1192,7 @@ std::string ComputeAtLogicalDomainMap::toString() const {
ComputeAtLogicalDomainMapBuilder::ComputeAtLogicalDomainMapBuilder(
ComputeAtLogicalDomainMap& logical_map,
bool map_through_reduction)
: BackwardVisitor(false),
: BackwardVisitorNoDefaultHandlers(false),
logical_map_(logical_map),
map_through_reduction_(map_through_reduction) {
Fusion* fusion = FusionGuard::getCurFusion();
Expand Down Expand Up @@ -1174,7 +1361,7 @@ void ComputeAtLogicalDomainMapBuilder::dispatch(Expr* e) {
if (visited_.find(e) != visited_.end()) {
return;
}
BackwardVisitor::dispatch(e);
BackwardVisitorNoDefaultHandlers::dispatch(e);
visited_.insert(e);
}

Expand Down
156 changes: 2 additions & 154 deletions csrc/logical_domain_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ class UnmappableReductionDomains : private IterVisitor {
DomainKeyMap<DomainKeySet> reduction_domain_inputs_;
};

class ComputeAtLogicalDomainMapBuilder;

//! Models logical-domain mappings for computeAt
//!
//! Two iteration domains are mapped when computeAt of one iteration
Expand Down Expand Up @@ -401,160 +403,6 @@ class NVF_API ComputeAtLogicalDomainMap : public LogicalDomainMap {
std::unordered_set<IterDomain*> window_axes_;
};

//! Create a DisjointSets of logical IterDomains by traversing the
//! current fusion entirely. IterDomains that can be mapped each
//! other with computeAt are grouped into the same subset in the
//! DisjointSets.
class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor {
public:
explicit ComputeAtLogicalDomainMapBuilder(
ComputeAtLogicalDomainMap& logical_map,
bool map_through_reduction = false);

private:
//! Initialize the bcast map for fusion outputs
void initializeBcastMap(const TensorView* tv, const IterDomain* id);

//! Set a pair of producer-consumer domain keys as mappable
void setMapped(const DomainKey& producer, const DomainKey& consumer);

//! Records two domains are invalid to map
void setInvalid(const DomainKey& key1, const DomainKey& key2);

//! Check if no pair of domains is invalid to map
bool isInvalid(const DomainKeySet& domains) const;

//! Track a pair of producer-consumer domains as potentially mappable. Inserts
//! entries into pending_map_, but does not add anything into the logical_map_
//! (added when handle is called on a TensorView). Maybe mapped will, however,
//! immediately propagate broadcast iter domains.
void setMaybeMapped(
const TensorDomain* producer_td,
const IterDomain* producer_id,
const TensorDomain* consumer_td,
const IterDomain* consumer_id);

void addToPendingList(const DomainKey& producer, const DomainKey& consumer);

//! Map pointwise IterDomains from inputs of expressions to outputs.
//! Do not map reduction IterDomains in inputs.
void mapPointwiseLikeOp(Expr* e);

using BackwardVisitor::handle;

void dispatch(Expr* e) override;

void handle(UnaryOp* uop) override {
mapPointwiseLikeOp(uop);
}

void handle(BinaryOp* bop) override {
mapPointwiseLikeOp(bop);
}

void handle(TernaryOp* top) override {
mapPointwiseLikeOp(top);
}

void handle(RNGOp* top) override;

void handle(SelectOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(IndexSelectOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(GatherOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(ReductionOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(GroupedReductionOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(WelfordOp* wop) override {
mapPointwiseLikeOp(wop);
}

void handle(LoadStoreOp* ldst) override {
mapPointwiseLikeOp(ldst);
}

void handle(MmaOp* wop) override {
mapPointwiseLikeOp(wop);
}

void handle(ViewOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(ViewAsScalar* op) override;

void handle(BroadcastOp* op) override;

void handle(SqueezeOp* op) override;

void handle(ExpandOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(RepeatOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(PadOp* op) override {
// For compute-at, padded id should be mapped
mapPointwiseLikeOp(op);
}

void handle(SliceOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(CatOp* op) override {
// For compute-at, concat id should be mapped
mapPointwiseLikeOp(op);
}

void handle(EmbeddingFwdOp* op) override {
mapPointwiseLikeOp(op);
}

void handle(TensorView* tv) override;

//! Maps all pending mappings.
//! This is called for each of TensorViews in a backward traversal,
//! recursively building mappings from the output tensors to the
//! input tensors.
void mapAllPendingMappings(const DomainKey& key);

//! Maps all pending mappings for id of td. When id is a broadcast,
//! mapping is done separately for each concrete domain.
void mapAllPendingMappings(const TensorDomain* td, IterDomain* id);

bool safeToMap(const DomainKeySet& domains);

private:
ComputeAtLogicalDomainMap& logical_map_;
//! Keep track of what we want to try and map
DomainKeyMap<DomainKeySet> pending_map_;
std::unordered_set<Expr*> visited_;
//! Helper class to find invalid mappings due to reductions
UnmappableReductionDomains incompatible_domains_;
//! Running vector of domain pairs that are invalid to map
std::vector<std::pair<DomainKey, DomainKey>> invalid_mappings_;

//! Disable UnmappableReductions check, should
//! always be false for compute_at use cases
bool map_through_reduction_ = false;
};

//! Maps logical domains of an entire fusion. Does not map broadcast
//! domains with non-broadcast domains.
class NVF_API ExactLogicalDomainMap : public LogicalDomainMap {
Expand Down