@@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
798798 PatternRewriter &rewriter, ValueRange values,
799799 SmallVectorImpl<Value> &remapped);
800800
801- // / Returns true if the given operation is ignored, and does not need to be
801+ // / Return " true" if the given operation is ignored, and does not need to be
802802 // / converted.
803803 bool isOpIgnored (Operation *op) const ;
804804
805- // / Recursively marks the nested operations under 'op' as ignored. This
806- // / removes them from being considered for legalization.
807- void markNestedOpsIgnored (Operation *op);
805+ // / Return "true" if the given operation was replaced or erased.
806+ bool wasOpReplaced (Operation *op) const ;
808807
809808 // ===--------------------------------------------------------------------===//
810809 // Type Conversion
@@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
946945 // / Ordered list of block operations (creations, splits, motions).
947946 SmallVector<std::unique_ptr<IRRewrite>> rewrites;
948947
949- // / A set of operations that should no longer be considered for legalization,
950- // / but were not directly replace/erased/etc. by a pattern. These are
951- // / generally child operations of other operations who were
952- // / replaced/erased/etc. This is not meant to be an exhaustive list of all
953- // / operations, but the minimal set that can be used to detect if a given
954- // / operation should be `ignored`. For example, we may add the operations that
955- // / define non-empty regions to the set, but not any of the others. This
956- // / simplifies the amount of memory needed as we can query if the parent
957- // / operation was ignored.
948+ // / A set of operations that should no longer be considered for legalization.
949+ // / E.g., ops that are recursively legal. Ops that were replaced/erased are
950+ // / tracked separately.
958951 SetVector<Operation *> ignoredOps;
959952
960- // A set of operations that were erased.
953+ // / A set of operations that were replaced/erased. Such ops are not erased
954+ // / immediately but only when the dialect conversion succeeds. In the mean
955+ // / time, they should no longer be considered for legalization and any attempt
956+ // / to modify/access them is invalid rewriter API usage.
961957 SetVector<Operation *> replacedOps;
962958
963959 // / The current type converter, or nullptr if no type converter is currently
@@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
12371233 return success ();
12381234}
12391235
1240- // TODO: This function is a misnomer. It does not actually check if `op` is in
1241- // `ignoredOps`.
12421236bool ConversionPatternRewriterImpl::isOpIgnored (Operation *op) const {
1243- // Check to see if this operation or the parent operation is ignored .
1244- return ignoredOps .count (op-> getParentOp ()) || replacedOps .count (op);
1237+ // Check to see if this operation is ignored or was replaced .
1238+ return replacedOps .count (op) || ignoredOps .count (op);
12451239}
12461240
1247- void ConversionPatternRewriterImpl::markNestedOpsIgnored (Operation *op) {
1248- // Walk this operation and collect nested operations that define non-empty
1249- // regions. We mark such operations as 'ignored' so that we know we don't have
1250- // to convert them, or their nested ops.
1251- if (op->getNumRegions () == 0 )
1252- return ;
1253- op->walk ([&](Operation *op) {
1254- if (llvm::any_of (op->getRegions (),
1255- [](Region ®ion) { return !region.empty (); }))
1256- ignoredOps.insert (op);
1257- });
1241+ bool ConversionPatternRewriterImpl::wasOpReplaced (Operation *op) const {
1242+ // Check to see if this operation was replaced.
1243+ return replacedOps.count (op);
12581244}
12591245
12601246// ===----------------------------------------------------------------------===//
@@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14761462 logger.startLine () << " ** Insert : '" << op->getName () << " '(" << op
14771463 << " )\n " ;
14781464 });
1465+ assert (!wasOpReplaced (op->getParentOp ()) &&
1466+ " attempting to insert into a block within a replaced/erased op" );
1467+
14791468 if (!previous.isSet ()) {
14801469 // This is a newly created op.
14811470 appendRewrite<CreateOperationRewrite>(op);
@@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
14901479void ConversionPatternRewriterImpl::notifyOpReplaced (Operation *op,
14911480 ValueRange newValues) {
14921481 assert (newValues.size () == op->getNumResults ());
1493- assert (!replacedOps .contains (op) && " operation was already replaced" );
1482+ assert (!ignoredOps .contains (op) && " operation was already replaced" );
14941483
14951484 // Track if any of the results changed, e.g. erased and replaced with null.
14961485 bool resultChanged = false ;
@@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
15091498 appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
15101499 resultChanged);
15111500
1512- // Mark this operation as recursively ignored so that we don't need to
1513- // convert any nested operations.
1514- replacedOps.insert (op);
1515- markNestedOpsIgnored (op);
1501+ // Mark this operation and all nested ops as replaced.
1502+ op->walk ([&](Operation *op) { replacedOps.insert (op); });
15161503}
15171504
15181505void ConversionPatternRewriterImpl::notifyBlockIsBeingErased (Block *block) {
@@ -1523,6 +1510,9 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
15231510
15241511void ConversionPatternRewriterImpl::notifyBlockInserted (
15251512 Block *block, Region *previous, Region::iterator previousIt) {
1513+ assert (!wasOpReplaced (block->getParentOp ()) &&
1514+ " attempting to insert into a region within a replaced/erased op" );
1515+
15261516 if (!previous) {
15271517 // This is a newly created block.
15281518 appendRewrite<CreateBlockRewrite>(block);
@@ -1604,6 +1594,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
16041594}
16051595
16061596void ConversionPatternRewriter::eraseBlock (Block *block) {
1597+ assert (!impl->wasOpReplaced (block->getParentOp ()) &&
1598+ " attempting to erase a block within a replaced/erased op" );
1599+
16071600 // Mark all ops for erasure.
16081601 for (Operation &op : *block)
16091602 eraseOp (&op);
@@ -1619,18 +1612,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
16191612Block *ConversionPatternRewriter::applySignatureConversion (
16201613 Region *region, TypeConverter::SignatureConversion &conversion,
16211614 const TypeConverter *converter) {
1615+ assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1616+ " attempting to apply a signature conversion to a block within a "
1617+ " replaced/erased op" );
16221618 return impl->applySignatureConversion (region, conversion, converter);
16231619}
16241620
16251621FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes (
16261622 Region *region, const TypeConverter &converter,
16271623 TypeConverter::SignatureConversion *entryConversion) {
1624+ assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1625+ " attempting to apply a signature conversion to a block within a "
1626+ " replaced/erased op" );
16281627 return impl->convertRegionTypes (region, converter, entryConversion);
16291628}
16301629
16311630LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes (
16321631 Region *region, const TypeConverter &converter,
16331632 ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
1633+ assert (!impl->wasOpReplaced (region->getParentOp ()) &&
1634+ " attempting to apply a signature conversion to a block within a "
1635+ " replaced/erased op" );
16341636 return impl->convertNonEntryRegionTypes (region, converter, blockConversions);
16351637}
16361638
@@ -1665,6 +1667,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
16651667
16661668Block *ConversionPatternRewriter::splitBlock (Block *block,
16671669 Block::iterator before) {
1670+ assert (!impl->wasOpReplaced (block->getParentOp ()) &&
1671+ " attempting to split a block within a replaced/erased op" );
16681672 auto *continuation = block->splitBlock (before);
16691673 impl->notifySplitBlock (block, continuation);
16701674 return continuation;
@@ -1673,15 +1677,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
16731677void ConversionPatternRewriter::inlineBlockBefore (Block *source, Block *dest,
16741678 Block::iterator before,
16751679 ValueRange argValues) {
1680+ #ifndef NDEBUG
16761681 assert (argValues.size () == source->getNumArguments () &&
16771682 " incorrect # of argument replacement values" );
1678- #ifndef NDEBUG
1683+ assert (!impl->wasOpReplaced (source->getParentOp ()) &&
1684+ " attempting to inline a block from a replaced/erased op" );
1685+ assert (!impl->wasOpReplaced (dest->getParentOp ()) &&
1686+ " attempting to inline a block into a replaced/erased op" );
16791687 auto opIgnored = [&](Operation *op) { return impl->isOpIgnored (op); };
1680- #endif // NDEBUG
16811688 // The source block will be deleted, so it should not have any users (i.e.,
16821689 // there should be no predecessors).
16831690 assert (llvm::all_of (source->getUsers (), opIgnored) &&
16841691 " expected 'source' to have no predecessors" );
1692+ #endif // NDEBUG
16851693
16861694 impl->notifyBlockBeingInlined (dest, source, before);
16871695 for (auto it : llvm::zip (source->getArguments (), argValues))
@@ -1691,13 +1699,17 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
16911699}
16921700
16931701void ConversionPatternRewriter::startOpModification (Operation *op) {
1702+ assert (!impl->wasOpReplaced (op) &&
1703+ " attempting to modify a replaced/erased op" );
16941704#ifndef NDEBUG
16951705 impl->pendingRootUpdates .insert (op);
16961706#endif
16971707 impl->appendRewrite <ModifyOperationRewrite>(op);
16981708}
16991709
17001710void ConversionPatternRewriter::finalizeOpModification (Operation *op) {
1711+ assert (!impl->wasOpReplaced (op) &&
1712+ " attempting to modify a replaced/erased op" );
17011713 PatternRewriter::finalizeOpModification (op);
17021714 // There is nothing to do here, we only need to track the operation at the
17031715 // start of the update.
@@ -1912,8 +1924,13 @@ OperationLegalizer::legalize(Operation *op,
19121924
19131925 // If this operation is recursively legal, mark its children as ignored so
19141926 // that we don't consider them for legalization.
1915- if (legalityInfo->isRecursivelyLegal )
1916- rewriter.getImpl ().markNestedOpsIgnored (op);
1927+ if (legalityInfo->isRecursivelyLegal ) {
1928+ op->walk ([&](Operation *nested) {
1929+ if (op != nested)
1930+ rewriter.getImpl ().ignoredOps .insert (nested);
1931+ });
1932+ }
1933+
19171934 return success ();
19181935 }
19191936
0 commit comments