Fix prune output rules for intersect and except nodes #21343

Merged: feilong-liu merged 1 commit into prestodb:master from feilong-liu:fix_prune_output on Nov 16, 2023

@feilong-liu (Contributor) commented Nov 8, 2023

Description

Do not prune the outputs of Intersect and Except nodes in the prune unused output rule.

Motivation and Context

In Presto, Intersect and Except nodes are implemented as union+aggregation by the ImplementIntersectAndExceptAsUnion rule.
For example, the query SELECT k1 FROM (SELECT nationkey as k1, regionkey as k2 FROM nation intersect SELECT orderkey as k1, custkey as k2 FROM orders) is implemented as a union of aggregations over nation and orders, grouped by k1 and k2, with the per-side counts compared afterwards. The example plan is as follows:

presto:tpch> explain (type distributed) SELECT k1 FROM (SELECT nationkey as k1, regionkey as k2 FROM nation intersect SELECT orderkey as k1, custkey as k2 FROM orders);
                                                                                                                                                 >
------------------------------------------------------------------------------------------------------------------------------------------------->
 Fragment 0 [SINGLE]                                                                                                                             >
     Output layout: [nationkey_12]                                                                                                               >
     Output partitioning: SINGLE []                                                                                                              >
     Stage Execution Strategy: UNGROUPED_EXECUTION                                                                                               >
     - Output[PlanNodeId 13][k1] => [nationkey_12:bigint]                                                                                        >
             k1 := nationkey_12 (1:35)                                                                                                           >
         - RemoteSource[1] => [nationkey_12:bigint]                                                                                              >
                                                                                                                                                 >
 Fragment 1 [HASH]                                                                                                                               >
     Output layout: [nationkey_12]                                                                                                               >
     Output partitioning: SINGLE []                                                                                                              >
     Stage Execution Strategy: UNGROUPED_EXECUTION                                                                                               >
     - FilterProject[PlanNodeId 272,160][filterPredicate = ((count) >= (BIGINT'1')) AND ((count_29) >= (BIGINT'1')), projectLocality = LOCAL] => >
         - Project[PlanNodeId 662][projectLocality = LOCAL] => [nationkey_12:bigint, regionkey_13:bigint, count_29:bigint, count:bigint]         >
             - Aggregate(FINAL)[nationkey_12, regionkey_13][$hashvalue][PlanNodeId 140] => [nationkey_12:bigint, regionkey_13:bigint, $hashvalue:>
                     count_29 := "presto.default.count"((count_31))                                                                              >
                     count := "presto.default.count"((count_30))                                                                                 >
                 - LocalExchange[PlanNodeId 592][HASH][$hashvalue] (nationkey_12, regionkey_13) => [nationkey_12:bigint, regionkey_13:bigint, cou>
                     - Project[PlanNodeId 660][projectLocality = LOCAL] => [nationkey_12:bigint, regionkey_13:bigint, count_30:bigint, count_31:b>
                             $hashvalue_34 := combine_hash(combine_hash(BIGINT'0', COALESCE($operator$hash_code(nationkey_12), BIGINT'0')), COALE>
                         - Project[PlanNodeId 589][projectLocality = LOCAL] => [nationkey_12:bigint, regionkey_13:bigint, count_30:bigint, count_>
                                 nationkey_12 := nationkey (1:51)                                                                                >
                                 regionkey_13 := regionkey (1:68)                                                                                >
                             - RemoteSource[2] => [nationkey:bigint, regionkey:bigint, count_30:bigint, count_31:bigint, $hashvalue_32:bigint]   >
                     - Project[PlanNodeId 661][projectLocality = LOCAL] => [nationkey_12:bigint, regionkey_13:bigint, count_30:bigint, count_31:b>
                             $hashvalue_37 := combine_hash(combine_hash(BIGINT'0', COALESCE($operator$hash_code(nationkey_12), BIGINT'0')), COALE>
                         - Project[PlanNodeId 591][projectLocality = LOCAL] => [nationkey_12:bigint, regionkey_13:bigint, count_30:bigint, count_>
                                 nationkey_12 := orderkey (1:51)                                                                                 >
                                 regionkey_13 := custkey (1:68)                                                                                  >
                             - RemoteSource[3] => [orderkey:bigint, custkey:bigint, count_30:bigint, count_31:bigint, $hashvalue_35:bigint]      >
                                                                                                                                                 >
 Fragment 2 [SOURCE]                                                                                                                             >
     Output layout: [nationkey, regionkey, count_30, count_31, $hashvalue_33]                                                                    >
     Output partitioning: HASH [nationkey, regionkey][$hashvalue_33]                                                                             >
     Stage Execution Strategy: UNGROUPED_EXECUTION                                                                                               >
     - Aggregate(PARTIAL)[nationkey, regionkey][$hashvalue_33][PlanNodeId 599] => [nationkey:bigint, regionkey:bigint, $hashvalue_33:bigint, coun>
             count_30 := "presto.default.count"((marker_23))                                                                                     >
             count_31 := "presto.default.count"((marker_24))                                                                                     >
         - ScanProject[PlanNodeId 0,137][table = TableHandle {connectorId='hive', connectorHandle='HiveTableHandle{schemaName=tpch, tableName=nat>
                 Estimates: {source: CostBasedSourceInfo, rows: 25 (1.10kB), cpu: 450.00, memory: 0.00, network: 0.00}/{source: CostBasedSourceIn>
                 marker_23 := BOOLEAN'true'                                                                                                      >
                 marker_24 := null                                                                                                               >
                 $hashvalue_33 := combine_hash(combine_hash(BIGINT'0', COALESCE($operator$hash_code(nationkey), BIGINT'0')), COALESCE($operator$h>
                 LAYOUT: tpch.nation{}                                                                                                           >
                 nationkey := nationkey:bigint:0:REGULAR (1:89)                                                                                  >
                 regionkey := regionkey:bigint:2:REGULAR (1:89)                                                                                  >
                                                                                                                                                 >
 Fragment 3 [SOURCE]                                                                                                                             >
     Output layout: [orderkey, custkey, count_30, count_31, $hashvalue_36]                                                                       >
     Output partitioning: HASH [orderkey, custkey][$hashvalue_36]                                                                                >
     Stage Execution Strategy: UNGROUPED_EXECUTION                                                                                               >
     - Aggregate(PARTIAL)[orderkey, custkey][$hashvalue_36][PlanNodeId 605] => [orderkey:bigint, custkey:bigint, $hashvalue_36:bigint, count_30:b>
             count_30 := "presto.default.count"((marker_27))                                                                                     >
             count_31 := "presto.default.count"((marker_28))                                                                                     >
         - ScanProject[PlanNodeId 3,138][table = TableHandle {connectorId='hive', connectorHandle='HiveTableHandle{schemaName=tpch, tableName=ord>
                 Estimates: {source: CostBasedSourceInfo, rows: 15000 (659.18kB), cpu: 270000.00, memory: 0.00, network: 0.00}/{source: CostBased>
                 marker_27 := null                                                                                                               >
                 marker_28 := BOOLEAN'true'                                                                                                      >
                 $hashvalue_36 := combine_hash(combine_hash(BIGINT'0', COALESCE($operator$hash_code(orderkey), BIGINT'0')), COALESCE($operator$ha>
                 LAYOUT: tpch.orders{}                                                                                                           >
                 orderkey := orderkey:bigint:0:REGULAR (1:148)                                                                                   >
                 custkey := custkey:bigint:1:REGULAR (1:148)  

However, the current prune output rule prunes k2 from the outputs of the intersect node because k2 is not referenced by the final output, which leads to incorrect results.
Fortunately, the prune output rule currently runs only after the ImplementIntersectAndExceptAsUnion rule, by which point Intersect and Except nodes no longer exist, so we have not hit this bug.
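
A minimal sketch of why pruning k2 is unsafe, using plain Java collections rather than Presto planner classes. The class name, helper names, and sample rows are all hypothetical; the intersect helper only imitates the per-side counting visible in the plan above (count >= 1 on both sides) and is not the planner's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class IntersectPruneSketch
{
    // Hypothetical sample rows of the form (k1, k2); not taken from the tpch tables.
    static final List<List<Long>> LEFT = Arrays.asList(Arrays.asList(1L, 10L), Arrays.asList(2L, 20L));
    static final List<List<Long>> RIGHT = Arrays.asList(Arrays.asList(1L, 99L), Arrays.asList(2L, 20L));

    // Imitates INTERSECT as union + aggregation: count how often each distinct row
    // appears on each side, then keep rows seen at least once on both sides.
    static List<List<Long>> intersect(List<List<Long>> left, List<List<Long>> right)
    {
        Map<List<Long>, long[]> counts = new LinkedHashMap<>();
        for (List<Long> row : left) {
            counts.computeIfAbsent(row, key -> new long[2])[0]++;
        }
        for (List<Long> row : right) {
            counts.computeIfAbsent(row, key -> new long[2])[1]++;
        }
        List<List<Long>> result = new ArrayList<>();
        for (Map.Entry<List<Long>, long[]> entry : counts.entrySet()) {
            if (entry.getValue()[0] >= 1 && entry.getValue()[1] >= 1) {
                result.add(entry.getKey());
            }
        }
        return result;
    }

    // Keeps only the first column (k1) of every row.
    static List<List<Long>> projectK1(List<List<Long>> rows)
    {
        List<List<Long>> projected = new ArrayList<>();
        for (List<Long> row : rows) {
            projected.add(Collections.singletonList(row.get(0)));
        }
        return projected;
    }

    // Correct order: intersect on (k1, k2), then project k1.
    static List<List<Long>> intersectThenProject()
    {
        return projectK1(intersect(LEFT, RIGHT));
    }

    // Buggy order: k2 is pruned below the intersect, so rows are compared on k1 only.
    static List<List<Long>> pruneThenIntersect()
    {
        return intersect(projectK1(LEFT), projectK1(RIGHT));
    }

    public static void main(String[] args)
    {
        System.out.println(intersectThenProject()); // [[2]]
        System.out.println(pruneThenIntersect());   // [[1], [2]]
    }
}
```

In this sketch the row with k1 = 1 differs between the two sides only in k2, so it appears in the result only when k2 has been pruned before the comparison.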

Impact

Fix a potential correctness issue.

Test Plan

Unit test

Contributor checklist

  • Please make sure your submission complies with our development, formatting, commit message, and attribution guidelines.
  • PR description addresses the issue accurately and concisely. If the change is non-trivial, a GitHub Issue is referenced.
  • Documented new properties (with its default value), SQL syntax, functions, or other functionality.
  • If release notes are required, they follow the release notes guidelines.
  • Adequate tests were added if applicable.
  • CI passed.

Release Notes

Please follow release notes guidelines and fill in the release notes below.

== RELEASE NOTES ==

General Changes
* Fix a potential bug in except and intersect queries. Do not prune unreferenced output in intersect and except nodes.

@feilong-liu feilong-liu requested a review from a team as a code owner November 8, 2023 19:10
@jaystarshot jaystarshot self-assigned this Nov 9, 2023

@jaystarshot (Member) commented Nov 9, 2023

Nit- The method is called rewriteSetOperationVariableMapping, so Instead of changing the method to not prune, can we use inline and exit early where it should be false?

@feilong-liu (Contributor Author) replied:

rewriteSetOperationVariableMapping does two things: 1) it prunes variables, and 2) it rewrites Map<Variable, List<Variable>> to ListMultimap<Variable, Variable>. Intersect and except nodes should not be pruned. Ideally we would not even need to rewrite to a ListMultimap here, but that would also require changing rewriteSetOperationSubPlans, which seems like too much change for this fix. After some thought, I found that the current approach needs the fewest code changes and states the intent most directly, i.e. skip pruning for these two nodes.
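
A rough sketch of the "skip pruning via a flag" idea described above, assuming a simplified model in which variables are plain strings and the mapping is an ordinary Map instead of a Guava ListMultimap. The class, method, and parameter names are hypothetical and do not reproduce the actual signature of rewriteSetOperationVariableMapping.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

final class SetOperationMappingSketch
{
    // Copies the output-to-inputs mapping of a set operation. When the flag is true,
    // outputs the parent does not expect are dropped (union-style pruning); when it
    // is false, every output is kept (the behaviour wanted for intersect and except).
    static Map<String, List<String>> rewriteMapping(
            Map<String, List<String>> outputToInputs,
            Set<String> expectedOutputs,
            boolean pruneUnreferencedColumns)
    {
        Map<String, List<String>> rewritten = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> entry : outputToInputs.entrySet()) {
            if (!pruneUnreferencedColumns || expectedOutputs.contains(entry.getKey())) {
                rewritten.put(entry.getKey(), new ArrayList<>(entry.getValue()));
            }
        }
        return rewritten;
    }

    // Hypothetical mapping for the example query: each output has one input per source.
    static Map<String, List<String>> sampleMapping()
    {
        Map<String, List<String>> mapping = new LinkedHashMap<>();
        mapping.put("k1", Arrays.asList("nationkey", "orderkey"));
        mapping.put("k2", Arrays.asList("regionkey", "custkey"));
        return mapping;
    }

    public static void main(String[] args)
    {
        Set<String> expected = Collections.singleton("k1");
        System.out.println(rewriteMapping(sampleMapping(), expected, true).keySet());  // [k1]
        System.out.println(rewriteMapping(sampleMapping(), expected, false).keySet()); // [k1, k2]
    }
}
```

With the flag set to false, k2 stays in the mapping even though only k1 is requested, which is the behaviour the fix wants for intersect and except nodes.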

@jaystarshot (Member) commented Nov 9, 2023

Another way to do this would be:

    @Override
    public PlanNode visitIntersect(IntersectNode node, RewriteContext<Set<VariableReferenceExpression>> context)
    {
        // Treat every output of the intersect node as expected so that none of them is pruned.
        Set<VariableReferenceExpression> expectedInputs = new HashSet<>(context.get());
        expectedInputs.addAll(node.getOutputVariables());
        ListMultimap<VariableReferenceExpression, VariableReferenceExpression> rewrittenVariableMapping = rewriteSetOperationVariableMapping(node, expectedInputs);
        ImmutableList<PlanNode> rewrittenSubPlans = rewriteSetOperationSubPlans(node, context, rewrittenVariableMapping);
        return new IntersectNode(node.getSourceLocation(), node.getId(), rewrittenSubPlans, ImmutableList.copyOf(rewrittenVariableMapping.keySet()), fromListMultimap(rewrittenVariableMapping));
    }

Basically, add the intersect node's outputs to expectedInputs and refactor the rewriteSetOperationVariableMapping function to take a Set as the input. I think this is more in line logically with this optimizer.
This is not blocking the current approach, just a discussion point.

@jaystarshot jaystarshot removed their assignment Nov 9, 2023

Contributor:

nit: perhaps call the argument pruneUnreferencedColumns

@mlyublena (Contributor) commented:

I guess this bug could lead to correctness issues if we incorrectly prune columns (and thus may collapse rows that aren't duplicates in the extra column but agree on all other columns).
Do you have a test like that? Can we add it to one of our end to end test suites?

@feilong-liu (Contributor Author) commented:

We already have test cases for them:

assertQuery("SELECT COUNT(*), SUM(2), regionkey FROM (SELECT nationkey, regionkey FROM nation INTERSECT SELECT regionkey, regionkey FROM nation) n GROUP BY regionkey");
and
assertQuery("SELECT COUNT(*), SUM(2), regionkey FROM (SELECT nationkey, regionkey FROM nation EXCEPT SELECT regionkey, regionkey FROM nation) n GROUP BY regionkey HAVING regionkey < 3");

@mlyublena (Contributor) commented:

But do we have a test that produced wrong results before and is now fixed? These tests may have been accidentally correct.

@feilong-liu (Contributor Author) commented:

These tests will not fail, because currently the earliest run of PruneUnreferencedOutputs is after ImplementIntersectAndExceptAsUnion, which rewrites except and intersect to union+aggregation; this is why the bug is never triggered. But I did verify that when PruneUnreferencedOutputs runs before ImplementIntersectAndExceptAsUnion, this test breaks.

@kaikalur (Contributor) commented:

I actually wonder if we should change intersect and except to use join/left join. Aggregations are not very good for optimizations; joins are better.

@ajaygeorge (Contributor) left a comment:

Stamping since it is already reviewed.

We should not prune output of intersect and except nodes.
@feilong-liu feilong-liu merged commit 374bec5 into prestodb:master Nov 16, 2023
@feilong-liu feilong-liu deleted the fix_prune_output branch November 16, 2023 21:26
@wanglinsong wanglinsong mentioned this pull request Feb 12, 2024
64 tasks