Conversation

@jacobhinkle (Collaborator) commented Apr 19, 2025

While working on #4211 I missed implementing a handler for ScanOp, which led to a tricky debugging session. @naoyam suggested we should throw an error if we do not have a handler in place when building the ComputeAtLogicalDomainMap, so that's what this PR does.

Note that BackwardVisitorNoDefaultHandlers could also throw for unhandled Vals, and we could generate those overrides with a macro as well. The class could also move into iter_visitor.h, along with similar classes for the other visitors, if that proves useful.
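
For context, the throwing defaults are generated with an X-macro over all expression types. Here is a minimal, self-contained sketch of the pattern, with made-up op names standing in for the real DISPATCH_FOR_ALL_EXPRS list:

    #include <stdexcept>

    // Stand-ins for real Expr subclasses; the actual list comes from
    // DISPATCH_FOR_ALL_EXPRS in nvFuser's dispatch machinery.
    struct UnaryOp {};
    struct ScanOp {};

    // X-macro listing every expression type exactly once.
    #define FOR_ALL_EXPRS(M) M(UnaryOp) M(ScanOp)

    struct VisitorNoDefaultHandlers {
      virtual ~VisitorNoDefaultHandlers() = default;

    // Generate a throwing default handler for each listed type, so a
    // subclass that forgets an override fails loudly instead of silently
    // skipping the expression.
    #define M(T)                                                    \
      virtual void handle(T*) {                                     \
        throw std::runtime_error("Unhandled expression type: " #T); \
      }
      FOR_ALL_EXPRS(M)
    #undef M
    };

    struct MyVisitor : VisitorNoDefaultHandlers {
      void handle(UnaryOp*) override {} // handled
      // No override for ScanOp, so handling one throws at runtime.
    };

    int main() {
      MyVisitor v;
      UnaryOp u;
      v.handle(&u); // fine
      ScanOp s;
      try {
        v.handle(&s); // throws "Unhandled expression type: ScanOp"
      } catch (const std::exception&) {
      }
    }

A subclass then opts specific types back in with explicit overrides, as ComputeAtLogicalDomainMapBuilder does below.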

@jacobhinkle (Collaborator):

!test

github-actions bot commented Apr 19, 2025

Review updated until commit e33bfb8

Description

  • Added BackwardVisitorNoDefaultHandlers to throw exceptions for unhandled expression types.

  • Updated ComputeAtLogicalDomainMapBuilder to inherit from BackwardVisitorNoDefaultHandlers.

  • Added specific handlers for FullOp, IotaOp, EyeOp, and GetItem.

  • Added trivial override for ScatterOp.


Changes walkthrough 📝

Relevant files

Enhancement

csrc/logical_domain_map.cpp (+189/-2): Add exception throwing and specific handlers to ComputeAtLogicalDomainMapBuilder

  • Introduced BackwardVisitorNoDefaultHandlers class to throw exceptions
    for unhandled expressions.
  • Modified ComputeAtLogicalDomainMapBuilder to inherit from
    BackwardVisitorNoDefaultHandlers.
  • Added specific handlers for FullOp, IotaOp, EyeOp, and GetItem.
  • Added trivial override for ScatterOp.

Formatting

csrc/logical_domain_map.h (+2/-154): Move ComputeAtLogicalDomainMapBuilder class definition out of the header

  • Moved ComputeAtLogicalDomainMapBuilder class definition from .h to
    .cpp.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Missing Tests

The PR introduces a new class BackwardVisitorNoDefaultHandlers and modifies the behavior of ComputeAtLogicalDomainMapBuilder to throw exceptions for unhandled expression types. Tests covering these changes are needed to ensure the new behavior is correct and does not introduce regressions; a test sketch follows the snippet below.

    class BackwardVisitorNoDefaultHandlers : public BackwardVisitor {
     public:
      BackwardVisitorNoDefaultHandlers(bool map_through_reduction)
          : BackwardVisitor(map_through_reduction) {}
    
     protected:
      using BackwardVisitor::handle;
    
    #define M(e)                                     \
      void handle(e* uop) override {                 \
        NVF_THROW("Unhandled expression type: " #e); \
      }
      DISPATCH_FOR_ALL_EXPRS(M);
    #undef M
    };
    
    // Create a DisjointSets of logical IterDomains by traversing the
    // current fusion entirely. IterDomains that can be mapped each
    // other with computeAt are grouped into the same subset in the
    // DisjointSets.
    class ComputeAtLogicalDomainMapBuilder
        : private BackwardVisitorNoDefaultHandlers {
     public:
      explicit ComputeAtLogicalDomainMapBuilder(
          ComputeAtLogicalDomainMap& logical_map,
          bool map_through_reduction = false);
    
     private:
      //! Initialize the bcast map for fusion outputs
      void initializeBcastMap(const TensorView* tv, const IterDomain* id);
    
      //! Set a pair of producer-consumer domain keys as mappable
      void setMapped(const DomainKey& producer, const DomainKey& consumer);
    
      //! Records two domains are invalid to map
      void setInvalid(const DomainKey& key1, const DomainKey& key2);
    
      //! Check if no pair of domains is invalid to map
      bool isInvalid(const DomainKeySet& domains) const;
    
      //! Track a pair of producer-consumer domains as potentially mappable. Inserts
      //! entries into pending_map_, but does not add anything into the logical_map_
      //! (added when handle is called on a TensorView). Maybe mapped will, however,
      //! immediately propagate broadcast iter domains.
      void setMaybeMapped(
          const TensorDomain* producer_td,
          const IterDomain* producer_id,
          const TensorDomain* consumer_td,
          const IterDomain* consumer_id);
    
      void addToPendingList(const DomainKey& producer, const DomainKey& consumer);
    
      //! Map pointwise IterDomains from inputs of expressions to outputs.
      //! Do not map reduction IterDomains in inputs.
      void mapPointwiseLikeOp(Expr* e);
    
      using BackwardVisitorNoDefaultHandlers::handle;
    
      void dispatch(Expr* e) override;
    
      void handle(UnaryOp* uop) override {
        mapPointwiseLikeOp(uop);
      }
    
      void handle(BinaryOp* bop) override {
        mapPointwiseLikeOp(bop);
      }
    
      void handle(TernaryOp* top) override {
        mapPointwiseLikeOp(top);
      }
    
      void handle(RNGOp* top) override;
    
      void handle(SelectOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(IndexSelectOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(GatherOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(ReductionOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(GroupedReductionOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(WelfordOp* wop) override {
        mapPointwiseLikeOp(wop);
      }
    
      void handle(LoadStoreOp* ldst) override {
        mapPointwiseLikeOp(ldst);
      }
    
      void handle(MmaOp* wop) override {
        mapPointwiseLikeOp(wop);
      }
    
      void handle(ViewOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(ViewAsScalar* op) override;
    
      void handle(BroadcastOp* op) override;
    
      void handle(SqueezeOp* op) override;
    
      void handle(ExpandOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(RepeatOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(PadOp* op) override {
        // For compute-at, padded id should be mapped
        mapPointwiseLikeOp(op);
      }
    
      void handle(SliceOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(CatOp* op) override {
        // For compute-at, concat id should be mapped
        mapPointwiseLikeOp(op);
      }
    
      void handle(EmbeddingFwdOp* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(FullOp* op) override {}
    
      void handle(GetItem* op) override {
        mapPointwiseLikeOp(op);
      }
    
      void handle(IotaOp* op) override {}
    
      void handle(EyeOp* op) override {}
    
      void handle(ScatterOp* op) override {
        // TODO: I think we should map all dims like pointwise here other than
        // op->dim()
      }
    
      void handle(TensorView* tv) override;
    
      //! Maps all pending mappings.
      //! This is called for each of TensorViews in a backward traversal,
      //! recursively building mappings from the output tensors to the
      //! input tensors.
      void mapAllPendingMappings(const DomainKey& key);
    
      //! Maps all pending mappings for id of td. When id is a broadcast,
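
On the missing-tests point, one way to cover the new behavior might be a throw test along these lines. This is a sketch only: NVFuserTest, makeSymbolicTensor, and ComputeAtLogicalDomainMap::build() are assumed to match nvFuser's existing test utilities, and some_unhandled_op stands in for any op type that lacks a handler:

    // Sketch: a fusion containing an expression type with no handler should
    // make ComputeAtLogicalDomainMap::build() throw rather than silently
    // produce an incomplete map. some_unhandled_op is hypothetical.
    TEST_F(NVFuserTest, ComputeAtLogicalDomainMapUnhandledOpThrows) {
      Fusion fusion;
      FusionGuard fg(&fusion);

      auto tv0 = makeSymbolicTensor(2);
      fusion.addInput(tv0);
      auto tv1 = some_unhandled_op(tv0); // hypothetical op with no handler
      fusion.addOutput(tv1);

      ComputeAtLogicalDomainMap logical_map;
      EXPECT_ANY_THROW(logical_map.build());
    }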

Performance Impact

Throwing exceptions for unhandled expression types can have a performance impact. Since the new throws fire only when an expression has no handler at all, handled expressions dispatch exactly as before; the main question to evaluate is whether any expression type reachable in practice now hits the throwing path, which matters in performance-critical code like nvFuser.

Documentation

The PR introduces significant changes to the ComputeAtLogicalDomainMapBuilder class and its behavior. The documentation should be updated to reflect these changes and provide clear explanations of the new behavior and its implications.


@jacobhinkle (Collaborator):

!test

@jacobhinkle (Collaborator) commented Apr 19, 2025

On a quick first run I see a couple of failures, meaning we are missing some handlers, for example for FullOp.

@jacobhinkle (Collaborator):

!test

@jacobhinkle (Collaborator):

!test

    // current fusion entirely. IterDomains that can be mapped each
    // other with computeAt are grouped into the same subset in the
    // DisjointSets.
    class ComputeAtLogicalDomainMapBuilder

@jacobhinkle (Collaborator):

This did not need to be in the header, so I moved it to the .cpp. It is unchanged other than deriving from BackwardVisitorNoDefaultHandlers and implementing a few missing handlers.

Comment on lines +867 to +875

    void handle(FullOp* op) override {}

    void handle(GetItem* op) override {
      mapPointwiseLikeOp(op);
    }

    void handle(IotaOp* op) override {}

    void handle(EyeOp* op) override {}

@jacobhinkle (Collaborator):

Added these handlers. FullOp, IotaOp, and EyeOp take no TensorView inputs, so there is nothing to map backward and empty handlers suffice; GetItem maps pointwise.

@jacobhinkle (Collaborator):

!test --diff

Comment on lines +877 to +880

    void handle(ScatterOp* op) override {
      // TODO: I think we should map all dims like pointwise here other than
      // op->dim()
    }

@jacobhinkle (Collaborator):

Added this handler. Maybe it should call mapPointwiseLikeOp(op) or a variant instead, but since the handler was missing before, this empty override matches current behavior. One possible shape for that alternative is sketched below.
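
If the pointwise-like alternative were adopted, it might look roughly like the following. This is a sketch only: the in()/out() accessors and TensorDomain::logical() are assumptions based on the surrounding code, not verified against this PR:

    // Hypothetical alternative: map producer/consumer iter domains pointwise
    // except the scattered dimension op->dim(). Accessor names are assumed.
    // A fuller version would also skip reduction iter domains in the
    // producer, as mapPointwiseLikeOp does.
    void handle(ScatterOp* op) override {
      auto* in_tv = op->in()->as<TensorView>();
      auto* out_tv = op->out()->as<TensorView>();
      const auto& in_logical = in_tv->domain()->logical();
      const auto& out_logical = out_tv->domain()->logical();
      for (size_t i = 0; i < out_logical.size(); ++i) {
        if (static_cast<int64_t>(i) == op->dim()) {
          continue; // leave the scattered dimension unmapped
        }
        setMaybeMapped(
            in_tv->domain(), in_logical.at(i),
            out_tv->domain(), out_logical.at(i));
      }
    }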

@jacobhinkle marked this pull request as ready for review April 19, 2025 17:00
@jacobhinkle requested a review from naoyam April 19, 2025 17:00

@jacobhinkle (Collaborator):

Codediff seems to be due to nvrtc behavior. Generated CUDA code is unchanged.
