Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import java.sql.Connection

import org.apache.spark.SparkConf
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.connector.expressions.aggregate.{Corr, CovarPop, CovarSamp, StddevPop, StddevSamp, VarPop, VarSamp}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -50,9 +52,25 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes
.set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort))
.set("spark.sql.catalog.postgresql.pushDownTableSample", "true")
.set("spark.sql.catalog.postgresql.pushDownAggregate", "true")
.set("spark.sql.catalog.postgresql.pushDownLimit", "true")

override def dataPreparation(conn: Connection): Unit = {}
/**
 * Seeds the database with the fixture queried by the aggregate push-down tests:
 * schema "test" and table "test"."employee", populated with five rows.
 *
 * Every statement is closed as soon as it has executed, rather than being left open
 * for the lifetime of the connection.
 *
 * @param conn JDBC connection to the database under test
 */
override def dataPreparation(conn: Connection): Unit = {
  // Runs a single DDL/DML statement and always releases its PreparedStatement,
  // even when execution throws.
  def execute(statement: String): Unit = {
    val prepared = conn.prepareStatement(statement)
    try {
      prepared.executeUpdate()
    } finally {
      prepared.close()
    }
  }

  execute("CREATE SCHEMA \"test\"")
  execute(
    "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," +
      " bonus double precision)")
  // Rows are inserted in the same order as before: (dept, name, salary, bonus).
  Seq(
    "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)",
    "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)",
    "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)",
    "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)",
    "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)"
  ).foreach(execute)
}

override def testUpdateColumnType(tbl: String): Unit = {
sql(s"CREATE TABLE $tbl (ID INTEGER)")
Expand Down Expand Up @@ -80,4 +98,96 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes
}

// Reports table-sample support for this catalog (the suite's conf also sets
// pushDownTableSample to "true").
// NOTE(review): presumably read by V2JDBCTest to enable its TABLESAMPLE tests — confirm.
override def supportsTableSample: Boolean = true

// Verifies that VAR_POP / VAR_SAMP are pushed down to the JDBC source and that the
// per-department results are correct.
test("scan with aggregate push-down: VAR_POP VAR_SAMP") {
  val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM postgresql.test.employee" +
    " where dept > 0 group by DePt order by dept")

  // Gather the scans first: assertions placed directly inside `collect` are skipped
  // silently when the optimized plan contains no DataSourceV2ScanRelation at all.
  val scans = df.queryExecution.optimizedPlan.collect {
    case DataSourceV2ScanRelation(_, scan, _) => scan
  }
  assert(scans.nonEmpty, "optimized plan contains no DataSourceV2ScanRelation")
  scans.foreach {
    case wrapper: V1ScanWrapper =>
      val aggregation = wrapper.pushedDownOperators.aggregation
      assert(aggregation.isDefined)
      val pushedAggregates = aggregation.get.aggregateExpressions()
      assert(pushedAggregates(0).isInstanceOf[VarPop])
      assert(pushedAggregates(1).isInstanceOf[VarSamp])
    case other =>
      assert(false, s"expected a V1ScanWrapper but found ${other.getClass.getName}")
  }

  // One result row per department (1, 2, 6), ordered by dept.
  val rows = df.collect()
  assert(rows.length === 3)
  assert(rows(0).length === 2)
  // dept 1: bonuses 1000 and 1200
  assert(rows(0).getDouble(0) === 10000d)
  assert(rows(0).getDouble(1) === 20000d)
  // dept 2: bonuses 1200 and 1300
  assert(rows(1).getDouble(0) === 2500d)
  assert(rows(1).getDouble(1) === 5000d)
  // dept 6: a single row, so the sample variance is NULL
  assert(rows(2).getDouble(0) === 0d)
  assert(rows(2).isNullAt(1))
}

// Verifies that STDDEV_POP / STDDEV_SAMP are pushed down to the JDBC source and that
// the per-department results are correct.
test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP") {
  val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM postgresql.test.employee" +
    " where dept > 0 group by DePt order by dept")

  // Gather the scans first: assertions placed directly inside `collect` are skipped
  // silently when the optimized plan contains no DataSourceV2ScanRelation at all.
  val scans = df.queryExecution.optimizedPlan.collect {
    case DataSourceV2ScanRelation(_, scan, _) => scan
  }
  assert(scans.nonEmpty, "optimized plan contains no DataSourceV2ScanRelation")
  scans.foreach {
    case wrapper: V1ScanWrapper =>
      val aggregation = wrapper.pushedDownOperators.aggregation
      assert(aggregation.isDefined)
      val pushedAggregates = aggregation.get.aggregateExpressions()
      assert(pushedAggregates(0).isInstanceOf[StddevPop])
      assert(pushedAggregates(1).isInstanceOf[StddevSamp])
    case other =>
      assert(false, s"expected a V1ScanWrapper but found ${other.getClass.getName}")
  }

  // One result row per department (1, 2, 6), ordered by dept.
  val rows = df.collect()
  assert(rows.length === 3)
  assert(rows(0).length === 2)
  // dept 1: bonuses 1000 and 1200
  assert(rows(0).getDouble(0) === 100d)
  assert(rows(0).getDouble(1) === 141.4213562373095d)
  // dept 2: bonuses 1200 and 1300
  assert(rows(1).getDouble(0) === 50d)
  assert(rows(1).getDouble(1) === 70.71067811865476d)
  // dept 6: a single row, so the sample standard deviation is NULL
  assert(rows(2).getDouble(0) === 0d)
  assert(rows(2).isNullAt(1))
}

// Verifies that COVAR_POP / COVAR_SAMP are pushed down to the JDBC source together
// with a filter and a group by, and that the per-department results are correct.
test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
  val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus) FROM" +
    " postgresql.test.employee where dept > 0 group by DePt order by dept")

  // Gather the scans first: assertions placed directly inside `collect` are skipped
  // silently when the optimized plan contains no DataSourceV2ScanRelation at all.
  val scans = df.queryExecution.optimizedPlan.collect {
    case DataSourceV2ScanRelation(_, scan, _) => scan
  }
  assert(scans.nonEmpty, "optimized plan contains no DataSourceV2ScanRelation")
  scans.foreach {
    case wrapper: V1ScanWrapper =>
      val aggregation = wrapper.pushedDownOperators.aggregation
      assert(aggregation.isDefined)
      val pushedAggregates = aggregation.get.aggregateExpressions()
      assert(pushedAggregates(0).isInstanceOf[CovarPop])
      assert(pushedAggregates(1).isInstanceOf[CovarSamp])
    case other =>
      assert(false, s"expected a V1ScanWrapper but found ${other.getClass.getName}")
  }

  // One result row per department (1, 2, 6), ordered by dept. Both arguments are the
  // same column, so the expected values equal the variances of `bonus`.
  val rows = df.collect()
  assert(rows.length === 3)
  assert(rows(0).length === 2)
  // dept 1: bonuses 1000 and 1200
  assert(rows(0).getDouble(0) === 10000d)
  assert(rows(0).getDouble(1) === 20000d)
  // dept 2: bonuses 1200 and 1300
  assert(rows(1).getDouble(0) === 2500d)
  assert(rows(1).getDouble(1) === 5000d)
  // dept 6: a single row, so the sample covariance is NULL
  assert(rows(2).getDouble(0) === 0d)
  assert(rows(2).isNullAt(1))
}

// Verifies that CORR is pushed down to the JDBC source together with a filter and a
// group by, and that the per-department results are correct.
test("scan with aggregate push-down: CORR with filter and group by") {
  val df = sql("select CORR(bonus, bonus) FROM postgresql.test.employee where dept > 0" +
    " group by DePt order by dept")

  // Gather the scans first: assertions placed directly inside `collect` are skipped
  // silently when the optimized plan contains no DataSourceV2ScanRelation at all.
  val scans = df.queryExecution.optimizedPlan.collect {
    case DataSourceV2ScanRelation(_, scan, _) => scan
  }
  assert(scans.nonEmpty, "optimized plan contains no DataSourceV2ScanRelation")
  scans.foreach {
    case wrapper: V1ScanWrapper =>
      val aggregation = wrapper.pushedDownOperators.aggregation
      assert(aggregation.isDefined)
      val pushedAggregates = aggregation.get.aggregateExpressions()
      assert(pushedAggregates(0).isInstanceOf[Corr])
    case other =>
      assert(false, s"expected a V1ScanWrapper but found ${other.getClass.getName}")
  }

  // One result row per department (1, 2, 6), ordered by dept.
  val rows = df.collect()
  assert(rows.length === 3)
  assert(rows(0).length === 1)
  assert(rows(0).getDouble(0) === 1d)
  assert(rows(1).getDouble(0) === 1d)
  // dept 6 has a single row, for which the correlation is NULL
  assert(rows(2).isNullAt(0))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
 * Aggregate function describing {@code AVG}: the mean of the values of one column
 * within a group.
 *
 * @since 3.3.0
 */
@Evolving
public final class Average implements AggregateFunc {
  private final NamedReference column;

  /**
   * @param column reference to the column being averaged
   */
  public Average(NamedReference column) {
    this.column = column;
  }

  /** Returns the column reference supplied at construction. */
  public NamedReference column() {
    return column;
  }

  /** Produces the same text as {@link #describe()}. */
  @Override
  public String toString() {
    return describe();
  }

  /** Renders the function as {@code AVG(<column description>)}. */
  @Override
  public String describe() {
    return String.format("AVG(%s)", column.describe());
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
 * Aggregate function describing {@code CORR}: the Pearson correlation coefficient
 * between a set of number pairs.
 *
 * @since 3.3.0
 */
@Evolving
public final class Corr implements AggregateFunc {
  private final NamedReference left;
  private final NamedReference right;

  /**
   * @param left reference to the first column of each pair
   * @param right reference to the second column of each pair
   */
  public Corr(NamedReference left, NamedReference right) {
    this.left = left;
    this.right = right;
  }

  /** Returns the first column reference supplied at construction. */
  public NamedReference left() {
    return left;
  }

  /** Returns the second column reference supplied at construction. */
  public NamedReference right() {
    return right;
  }

  /** Produces the same text as {@link #describe()}. */
  @Override
  public String toString() {
    return describe();
  }

  /** Renders the function as {@code CORR(<left description>,<right description>)}. */
  @Override
  public String describe() {
    return String.format("CORR(%s,%s)", left.describe(), right.describe());
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
* An aggregate function that returns the population covariance of a set of number pairs.
*
* @since 3.3.0
*/
@Evolving
public final class CovarPop implements AggregateFunc {
private final NamedReference left;
private final NamedReference right;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not x and y?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or we can use left and right consistently

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK


/**
 * Creates a population-covariance aggregate over a pair of columns.
 *
 * @param left reference to the first column of each pair
 * @param right reference to the second column of each pair
 */
public CovarPop(
    NamedReference left,
    NamedReference right) {
  this.right = right;
  this.left = left;
}

/** Returns the first column reference supplied at construction. */
public NamedReference left() {
  return this.left;
}

/** Returns the second column reference supplied at construction. */
public NamedReference right() {
  return this.right;
}

/** Renders the function as {@code COVAR_POP(<left description>,<right description>)}. */
@Override
public String toString() {
  final String leftText = left.describe();
  final String rightText = right.describe();
  return String.format("COVAR_POP(%s,%s)", leftText, rightText);
}

/** Returns the same text as {@link #toString()}. */
@Override
public String describe() {
  return toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
 * Aggregate function describing {@code COVAR_SAMP}: the sample covariance of a set
 * of number pairs.
 *
 * @since 3.3.0
 */
@Evolving
public final class CovarSamp implements AggregateFunc {
  private final NamedReference left;
  private final NamedReference right;

  /**
   * @param left reference to the first column of each pair
   * @param right reference to the second column of each pair
   */
  public CovarSamp(NamedReference left, NamedReference right) {
    this.left = left;
    this.right = right;
  }

  /** Returns the first column reference supplied at construction. */
  public NamedReference left() {
    return left;
  }

  /** Returns the second column reference supplied at construction. */
  public NamedReference right() {
    return right;
  }

  /** Produces the same text as {@link #describe()}. */
  @Override
  public String toString() {
    return describe();
  }

  /** Renders the function as {@code COVAR_SAMP(<left description>,<right description>)}. */
  @Override
  public String describe() {
    return String.format("COVAR_SAMP(%s,%s)", left.describe(), right.describe());
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connector.expressions.aggregate;

import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.NamedReference;

/**
 * Aggregate function describing {@code STDDEV_POP}: the population standard deviation
 * of the values of one column within a group.
 *
 * @since 3.3.0
 */
@Evolving
public final class StddevPop implements AggregateFunc {
  private final NamedReference column;

  /**
   * @param column reference to the column the deviation is computed over
   */
  public StddevPop(NamedReference column) {
    this.column = column;
  }

  /** Returns the column reference supplied at construction. */
  public NamedReference column() {
    return column;
  }

  /** Produces the same text as {@link #describe()}. */
  @Override
  public String toString() {
    return describe();
  }

  /** Renders the function as {@code STDDEV_POP(<column description>)}. */
  @Override
  public String describe() {
    return String.format("STDDEV_POP(%s)", column.describe());
  }
}
Loading