-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-16046][DOCS] Aggregations in the Spark SQL programming guide #16329
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b2c08d5
0ee7c80
87a68bd
0b17e13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.examples.sql; | ||
|
|
||
| // $example on:typed_custom_aggregation$ | ||
| import java.io.Serializable; | ||
|
|
||
| import org.apache.spark.sql.Dataset; | ||
| import org.apache.spark.sql.Encoder; | ||
| import org.apache.spark.sql.Encoders; | ||
| import org.apache.spark.sql.SparkSession; | ||
| import org.apache.spark.sql.TypedColumn; | ||
| import org.apache.spark.sql.expressions.Aggregator; | ||
| // $example off:typed_custom_aggregation$ | ||
|
|
||
| public class JavaUserDefinedTypedAggregation { | ||
|
|
||
| // $example on:typed_custom_aggregation$ | ||
| public static class Employee implements Serializable { | ||
| private String name; | ||
| private long salary; | ||
|
|
||
| // Constructors, getters, setters... | ||
| // $example off:typed_custom_aggregation$ | ||
| public String getName() { | ||
| return name; | ||
| } | ||
|
|
||
| public void setName(String name) { | ||
| this.name = name; | ||
| } | ||
|
|
||
| public long getSalary() { | ||
| return salary; | ||
| } | ||
|
|
||
| public void setSalary(long salary) { | ||
| this.salary = salary; | ||
| } | ||
| // $example on:typed_custom_aggregation$ | ||
| } | ||
|
|
||
| public static class Average implements Serializable { | ||
| private long sum; | ||
| private long count; | ||
|
|
||
| // Constructors, getters, setters... | ||
| // $example off:typed_custom_aggregation$ | ||
| public Average() { | ||
| } | ||
|
|
||
| public Average(long sum, long count) { | ||
| this.sum = sum; | ||
| this.count = count; | ||
| } | ||
|
|
||
| public long getSum() { | ||
| return sum; | ||
| } | ||
|
|
||
| public void setSum(long sum) { | ||
| this.sum = sum; | ||
| } | ||
|
|
||
| public long getCount() { | ||
| return count; | ||
| } | ||
|
|
||
| public void setCount(long count) { | ||
| this.count = count; | ||
| } | ||
| // $example on:typed_custom_aggregation$ | ||
| } | ||
|
|
||
| public static class MyAverage extends Aggregator<Employee, Average, Double> { | ||
| // A zero value for this aggregation. Should satisfy the property that any b + zero = b | ||
| public Average zero() { | ||
| return new Average(0L, 0L); | ||
| } | ||
| // Combine two values to produce a new value. For performance, the function may modify `buffer` | ||
| // and return it instead of constructing a new object | ||
| public Average reduce(Average buffer, Employee employee) { | ||
| long newSum = buffer.getSum() + employee.getSalary(); | ||
| long newCount = buffer.getCount() + 1; | ||
| buffer.setSum(newSum); | ||
| buffer.setCount(newCount); | ||
| return buffer; | ||
| } | ||
| // Merge two intermediate values | ||
| public Average merge(Average b1, Average b2) { | ||
| long mergedSum = b1.getSum() + b2.getSum(); | ||
| long mergedCount = b1.getCount() + b2.getCount(); | ||
| b1.setSum(mergedSum); | ||
| b1.setCount(mergedCount); | ||
| return b1; | ||
| } | ||
| // Transform the output of the reduction | ||
| public Double finish(Average reduction) { | ||
| return ((double) reduction.getSum()) / reduction.getCount(); | ||
| } | ||
| // Specifies the Encoder for the intermediate value type | ||
| public Encoder<Average> bufferEncoder() { | ||
| return Encoders.bean(Average.class); | ||
| } | ||
| // Specifies the Encoder for the final output value type | ||
| public Encoder<Double> outputEncoder() { | ||
| return Encoders.DOUBLE(); | ||
| } | ||
| } | ||
| // $example off:typed_custom_aggregation$ | ||
|
|
||
| public static void main(String[] args) { | ||
| SparkSession spark = SparkSession | ||
| .builder() | ||
| .appName("Java Spark SQL user-defined Datasets aggregation example") | ||
| .getOrCreate(); | ||
|
|
||
| // $example on:typed_custom_aggregation$ | ||
| Encoder<Employee> employeeEncoder = Encoders.bean(Employee.class); | ||
| String path = "examples/src/main/resources/employees.json"; | ||
| Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder); | ||
| ds.show(); | ||
| // +-------+------+ | ||
| // | name|salary| | ||
| // +-------+------+ | ||
| // |Michael| 3000| | ||
| // | Andy| 4500| | ||
| // | Justin| 3500| | ||
| // | Berta| 4000| | ||
| // +-------+------+ | ||
|
|
||
| MyAverage myAverage = new MyAverage(); | ||
| // Convert the function to a `TypedColumn` and give it a name | ||
| TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary"); | ||
| Dataset<Double> result = ds.select(averageSalary); | ||
| result.show(); | ||
| // +--------------+ | ||
| // |average_salary| | ||
| // +--------------+ | ||
| // | 3750.0| | ||
| // +--------------+ | ||
| // $example off:typed_custom_aggregation$ | ||
| spark.stop(); | ||
| } | ||
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.spark.examples.sql; | ||
|
|
||
| // $example on:untyped_custom_aggregation$ | ||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
|
|
||
| import org.apache.spark.sql.Dataset; | ||
| import org.apache.spark.sql.Row; | ||
| import org.apache.spark.sql.SparkSession; | ||
| import org.apache.spark.sql.expressions.MutableAggregationBuffer; | ||
| import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; | ||
| import org.apache.spark.sql.types.DataType; | ||
| import org.apache.spark.sql.types.DataTypes; | ||
| import org.apache.spark.sql.types.StructField; | ||
| import org.apache.spark.sql.types.StructType; | ||
| // $example off:untyped_custom_aggregation$ | ||
|
|
||
| public class JavaUserDefinedUntypedAggregation { | ||
|
|
||
| // $example on:untyped_custom_aggregation$ | ||
| public static class MyAverage extends UserDefinedAggregateFunction { | ||
|
||
|
|
||
| private StructType inputSchema; | ||
| private StructType bufferSchema; | ||
|
|
||
| public MyAverage() { | ||
| List<StructField> inputFields = new ArrayList<>(); | ||
| inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true)); | ||
| inputSchema = DataTypes.createStructType(inputFields); | ||
|
|
||
| List<StructField> bufferFields = new ArrayList<>(); | ||
| bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true)); | ||
| bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true)); | ||
| bufferSchema = DataTypes.createStructType(bufferFields); | ||
| } | ||
| // Data types of input arguments of this aggregate function | ||
| public StructType inputSchema() { | ||
| return inputSchema; | ||
| } | ||
| // Data types of values in the aggregation buffer | ||
| public StructType bufferSchema() { | ||
| return bufferSchema; | ||
| } | ||
| // The data type of the returned value | ||
| public DataType dataType() { | ||
| return DataTypes.DoubleType; | ||
| } | ||
| // Whether this function always returns the same output on the identical input | ||
| public boolean deterministic() { | ||
| return true; | ||
| } | ||
| // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to | ||
| // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides | ||
| // the opportunity to update its values. Note that arrays and maps inside the buffer are still | ||
| // immutable. | ||
| public void initialize(MutableAggregationBuffer buffer) { | ||
| buffer.update(0, 0L); | ||
| buffer.update(1, 0L); | ||
| } | ||
| // Updates the given aggregation buffer `buffer` with new input data from `input` | ||
| public void update(MutableAggregationBuffer buffer, Row input) { | ||
| if (!input.isNullAt(0)) { | ||
| long updatedSum = buffer.getLong(0) + input.getLong(0); | ||
| long updatedCount = buffer.getLong(1) + 1; | ||
| buffer.update(0, updatedSum); | ||
| buffer.update(1, updatedCount); | ||
| } | ||
| } | ||
| // Merges two aggregation buffers and stores the updated buffer values back to `buffer1` | ||
| public void merge(MutableAggregationBuffer buffer1, Row buffer2) { | ||
| long mergedSum = buffer1.getLong(0) + buffer2.getLong(0); | ||
| long mergedCount = buffer1.getLong(1) + buffer2.getLong(1); | ||
| buffer1.update(0, mergedSum); | ||
| buffer1.update(1, mergedCount); | ||
| } | ||
| // Calculates the final result | ||
| public Double evaluate(Row buffer) { | ||
| return ((double) buffer.getLong(0)) / buffer.getLong(1); | ||
| } | ||
| } | ||
| // $example off:untyped_custom_aggregation$ | ||
|
|
||
| public static void main(String[] args) { | ||
| SparkSession spark = SparkSession | ||
| .builder() | ||
| .appName("Java Spark SQL user-defined DataFrames aggregation example") | ||
| .getOrCreate(); | ||
|
|
||
| // $example on:untyped_custom_aggregation$ | ||
| // Register the function to access it | ||
| spark.udf().register("myAverage", new MyAverage()); | ||
|
|
||
| Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json"); | ||
| df.createOrReplaceTempView("employees"); | ||
| df.show(); | ||
| // +-------+------+ | ||
| // | name|salary| | ||
| // +-------+------+ | ||
| // |Michael| 3000| | ||
| // | Andy| 4500| | ||
| // | Justin| 3500| | ||
| // | Berta| 4000| | ||
| // +-------+------+ | ||
|
|
||
| Dataset<Row> result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees"); | ||
| result.show(); | ||
| // +--------------+ | ||
| // |average_salary| | ||
| // +--------------+ | ||
| // | 3750.0| | ||
| // +--------------+ | ||
| // $example off:untyped_custom_aggregation$ | ||
|
|
||
| spark.stop(); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| {"name":"Michael", "salary":3000} | ||
| {"name":"Andy", "salary":4500} | ||
| {"name":"Justin", "salary":3500} | ||
| {"name":"Berta", "salary":4000} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this meant to be
MyAverage?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@srowen
Averageis a Java bean that holds current sum and count. It is defined earlier. Here it represents a zero value.MyAverage, in turn, is the actual aggregator that accepts instances of theEmployeeclass, stores intermediate results using an instance ofAverage, and producesDoubleas a result.I can rename
MyAveragetoMyAverageAggregatorif this makes things clearer.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad, I read this incorrectly while skimming.