diff --git a/sql/src/main/scala/org/apache/spark/sql/geosparksql/expressions/AggregateFunctions.scala b/sql/src/main/scala/org/apache/spark/sql/geosparksql/expressions/AggregateFunctions.scala
index 8d5d9390bd7..0d8415807a9 100644
--- a/sql/src/main/scala/org/apache/spark/sql/geosparksql/expressions/AggregateFunctions.scala
+++ b/sql/src/main/scala/org/apache/spark/sql/geosparksql/expressions/AggregateFunctions.scala
@@ -17,27 +17,17 @@
 package org.apache.spark.sql.geosparksql.expressions
 
 import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory}
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
-import org.apache.spark.sql.geosparksql.UDT.GeometryUDT
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
 
 /**
- * Return the polygon union of all Polygon in the given column
- */
-
-class ST_Union_Aggr extends UserDefinedAggregateFunction {
-  override def inputSchema: StructType = StructType(StructField("Union", new GeometryUDT) :: Nil)
-
-  override def bufferSchema: StructType = StructType(
-    StructField("Union", new GeometryUDT) :: Nil
-  )
-
-  override def dataType: DataType = new GeometryUDT
-
-  override def deterministic: Boolean = true
+ * Trait providing the parts shared by all ST aggregate functions
+ */
 
-  override def initialize(buffer: MutableAggregationBuffer): Unit = {
+trait TraitSTAggregateExec {
+  val initialGeometry: Geometry = {
+    // Dummy polygon used as the initial (zero) value;
+    // any other sentinel geometry would work equally well.
     val coordinates: Array[Coordinate] = new Array[Coordinate](5)
     coordinates(0) = new Coordinate(-999999999, -999999999)
     coordinates(1) = new Coordinate(-999999999, -999999999)
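Note: the hunk above drops the UserDefinedAggregateFunction machinery (deprecated since Spark 3.0) in favor of the typed Aggregator API. For orientation, here is a minimal toy sketch of the Aggregator contract the rewritten classes implement; the names are illustrative and not part of the patch:

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator

    // Toy aggregator summing string lengths, showing the callbacks the
    // GeoSpark classes below provide for Geometry values.
    object LengthSum extends Aggregator[String, Long, Long] {
      def zero: Long = 0L                                        // initial buffer value
      def reduce(buf: Long, in: String): Long = buf + in.length  // fold in one input row
      def merge(b1: Long, b2: Long): Long = b1 + b2              // combine partial buffers
      def finish(buf: Long): Long = buf                          // produce final result
      def bufferEncoder: Encoder[Long] = Encoders.scalaLong
      def outputEncoder: Encoder[Long] = Encoders.scalaLong
    }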
@@ -45,76 +35,62 @@ class ST_Union_Aggr extends UserDefinedAggregateFunction {
     coordinates(3) = new Coordinate(-999999999, -999999999)
     coordinates(4) = coordinates(0)
     val geometryFactory = new GeometryFactory()
-    buffer(0) = geometryFactory.createPolygon(coordinates)
+    geometryFactory.createPolygon(coordinates)
   }
+  def zero: Geometry = initialGeometry
+  val serde = ExpressionEncoder[Geometry]()
 
-  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-    val accumulateUnion = buffer.getAs[Geometry](0)
-    val newPolygon = input.getAs[Geometry](0)
-    if (accumulateUnion.getArea == 0) buffer(0) = newPolygon
-    else buffer(0) = accumulateUnion.union(newPolygon)
-  }
+  def bufferEncoder: ExpressionEncoder[Geometry] = serde
+  def outputEncoder: ExpressionEncoder[Geometry] = serde
+  def finish(out: Geometry): Geometry = out
+}
 
-  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-    val leftPolygon = buffer1.getAs[Geometry](0)
-    val rightPolygon = buffer2.getAs[Geometry](0)
-    if (leftPolygon.getCoordinates()(0).x == -999999999) buffer1(0) = rightPolygon
-    else if (rightPolygon.getCoordinates()(0).x == -999999999) buffer1(0) = leftPolygon
-    else buffer1(0) = leftPolygon.union(rightPolygon)
-  }
+/**
+ * Return the polygon union of all Polygons in the given column
+ */
+class ST_Union_Aggr extends Aggregator[Geometry, Geometry, Geometry] with TraitSTAggregateExec {
 
-  override def evaluate(buffer: Row): Any = {
-    return buffer.getAs[Geometry](0)
+  def reduce(buffer: Geometry, input: Geometry): Geometry = {
+    if (buffer.equalsExact(initialGeometry)) input
+    else buffer.union(input)
   }
-}
 
-/**
- * Return the envelope boundary of the entire column
- */
-class ST_Envelope_Aggr extends UserDefinedAggregateFunction {
-  // This is the input fields for your aggregate function.
-  override def inputSchema: org.apache.spark.sql.types.StructType =
-    StructType(StructField("Envelope", new GeometryUDT) :: Nil)
+  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
+    if (buffer1.equals(initialGeometry)) buffer2
+    else if (buffer2.equals(initialGeometry)) buffer1
+    else buffer1.union(buffer2)
+  }
 
-  // This is the internal fields you keep for computing your aggregate.
-  override def bufferSchema: StructType = StructType(
-    StructField("Envelope", new GeometryUDT) :: Nil
-  )
-  // This is the output type of your aggregatation function.
-  override def dataType: DataType = new GeometryUDT
+}
 
-  override def deterministic: Boolean = true
-  // This is the initial value for your buffer schema.
-  override def initialize(buffer: MutableAggregationBuffer): Unit = {
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    coordinates(0) = new Coordinate(-999999999, -999999999)
-    coordinates(1) = new Coordinate(-999999999, -999999999)
-    coordinates(2) = new Coordinate(-999999999, -999999999)
-    coordinates(3) = new Coordinate(-999999999, -999999999)
-    coordinates(4) = new Coordinate(-999999999, -999999999)
-    val geometryFactory = new GeometryFactory()
-    buffer(0) = geometryFactory.createPolygon(coordinates)
-    //buffer(0) = new GenericArrayData(GeometrySerializer.serialize(geometryFactory.createPolygon(coordinates)))
-  }
+/**
+ * Return the envelope boundary of the entire column
+ */
+class ST_Envelope_Aggr extends Aggregator[Geometry, Geometry, Geometry] with TraitSTAggregateExec {
 
-  // This is how to update your buffer schema given an input.
-  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-    val accumulateEnvelope = buffer.getAs[Geometry](0).getEnvelopeInternal
-    val newEnvelope = input.getAs[Geometry](0).getEnvelopeInternal
+  def reduce(buffer: Geometry, input: Geometry): Geometry = {
+    val accumulateEnvelope = buffer.getEnvelopeInternal
+    val newEnvelope = input.getEnvelopeInternal
     val coordinates: Array[Coordinate] = new Array[Coordinate](5)
     var minX = 0.0
     var minY = 0.0
     var maxX = 0.0
     var maxY = 0.0
-    if (accumulateEnvelope.getMinX == -999999999) {
+    if (accumulateEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
       // Found the accumulateEnvelope is the initial value
       minX = newEnvelope.getMinX
       minY = newEnvelope.getMinY
       maxX = newEnvelope.getMaxX
       maxY = newEnvelope.getMaxY
     }
+    else if (newEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
+      minX = accumulateEnvelope.getMinX
+      minY = accumulateEnvelope.getMinY
+      maxX = accumulateEnvelope.getMaxX
+      maxY = accumulateEnvelope.getMaxY
+    }
     else {
       minX = Math.min(accumulateEnvelope.getMinX, newEnvelope.getMinX)
       minY = Math.min(accumulateEnvelope.getMinY, newEnvelope.getMinY)
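Besides SQL registration (see UdfRegistrator further down), a typed Aggregator can also be applied directly to a Dataset. A minimal sketch, assuming a Dataset[Geometry] named geoms and an implicit Encoder[Geometry] in scope (both illustrative, not from the patch):

    import com.vividsolutions.jts.geom.Geometry
    import org.apache.spark.sql.Dataset

    // Aggregator.toColumn yields a TypedColumn[Geometry, Geometry] that can be
    // passed to Dataset.select for a whole-column aggregation.
    def unionAll(geoms: Dataset[Geometry]): Geometry =
      geoms.select(new ST_Union_Aggr().toColumn).head()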
@@ -127,94 +103,62 @@ class ST_Envelope_Aggr extends UserDefinedAggregateFunction {
     coordinates(3) = new Coordinate(maxX, minY)
     coordinates(4) = coordinates(0)
     val geometryFactory = new GeometryFactory()
-    buffer(0) = geometryFactory.createPolygon(coordinates)
+    geometryFactory.createPolygon(coordinates)
+  }
 
-  // This is how to merge two objects with the bufferSchema type.
-  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-    val leftEnvelope = buffer1.getAs[Geometry](0).getEnvelopeInternal
-    val rightEnvelope = buffer2.getAs[Geometry](0).getEnvelopeInternal
+  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
+    val leftEnvelope = buffer1.getEnvelopeInternal
+    val rightEnvelope = buffer2.getEnvelopeInternal
     val coordinates: Array[Coordinate] = new Array[Coordinate](5)
     var minX = 0.0
     var minY = 0.0
     var maxX = 0.0
     var maxY = 0.0
-    if (leftEnvelope.getMinX == -999999999) {
-      // Found the leftEnvelope is the initial value
+    if (leftEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
       minX = rightEnvelope.getMinX
       minY = rightEnvelope.getMinY
       maxX = rightEnvelope.getMaxX
       maxY = rightEnvelope.getMaxY
     }
-    else if (rightEnvelope.getMinX == -999999999) {
-      // Found the rightEnvelope is the initial value
+    else if (rightEnvelope.equals(initialGeometry.getEnvelopeInternal)) {
      minX = leftEnvelope.getMinX
       minY = leftEnvelope.getMinY
       maxX = leftEnvelope.getMaxX
       maxY = leftEnvelope.getMaxY
     }
     else {
       minX = Math.min(leftEnvelope.getMinX, rightEnvelope.getMinX)
       minY = Math.min(leftEnvelope.getMinY, rightEnvelope.getMinY)
       maxX = Math.max(leftEnvelope.getMaxX, rightEnvelope.getMaxX)
       maxY = Math.max(leftEnvelope.getMaxY, rightEnvelope.getMaxY)
     }
+
     coordinates(0) = new Coordinate(minX, minY)
     coordinates(1) = new Coordinate(minX, maxY)
     coordinates(2) = new Coordinate(maxX, maxY)
     coordinates(3) = new Coordinate(maxX, minY)
     coordinates(4) = coordinates(0)
     val geometryFactory = new GeometryFactory()
-    buffer1(0) = geometryFactory.createPolygon(coordinates)
+    geometryFactory.createPolygon(coordinates)
   }
 
-  // This is where you output the final value, given the final value of your bufferSchema.
-  override def evaluate(buffer: Row): Any = {
-    return buffer.getAs[Geometry](0)
-  }
+
 }
 
 /**
  * Return the polygon intersection of all Polygon in the given column
  */
-class ST_Intersection_Aggr extends UserDefinedAggregateFunction {
-  override def inputSchema: StructType = StructType(StructField("Intersection", new GeometryUDT) :: Nil)
-
-  override def bufferSchema: StructType = StructType(
-    StructField("Intersection", new GeometryUDT) :: Nil
-  )
-
-  override def dataType: DataType = new GeometryUDT
-
-  override def deterministic: Boolean = true
-
-  override def initialize(buffer: MutableAggregationBuffer): Unit = {
-    val coordinates: Array[Coordinate] = new Array[Coordinate](5)
-    coordinates(0) = new Coordinate(-999999999, -999999999)
-    coordinates(1) = new Coordinate(-999999999, -999999999)
-    coordinates(2) = new Coordinate(-999999999, -999999999)
-    coordinates(3) = new Coordinate(-999999999, -999999999)
-    coordinates(4) = new Coordinate(-999999999, -999999999)
-    val geometryFactory = new GeometryFactory()
-    buffer(0) = geometryFactory.createPolygon(coordinates)
+class ST_Intersection_Aggr extends Aggregator[Geometry, Geometry, Geometry] with TraitSTAggregateExec {
+  def reduce(buffer: Geometry, input: Geometry): Geometry = {
+    if (buffer.isEmpty) input
+    else if (buffer.equalsExact(initialGeometry)) input
+    else buffer.intersection(input)
   }
-
-  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-    val accumulateIntersection = buffer.getAs[Geometry](0)
-    val newPolygon = input.getAs[Geometry](0)
-    if (accumulateIntersection.getArea == 0) buffer(0) = newPolygon
-    else buffer(0) = accumulateIntersection.intersection(newPolygon)
-  }
-
-  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-    val leftPolygon = buffer1.getAs[Geometry](0)
-    val rightPolygon = buffer2.getAs[Geometry](0)
-    if (leftPolygon.getCoordinates()(0).x == -999999999) buffer1(0) = rightPolygon
-    else if (rightPolygon.getCoordinates()(0).x == -999999999) buffer1(0) = leftPolygon
-    else buffer1(0) = leftPolygon.intersection(rightPolygon)
+  def merge(buffer1: Geometry, buffer2: Geometry): Geometry = {
+    if (buffer1.equalsExact(initialGeometry)) buffer2
+    else if (buffer2.equalsExact(initialGeometry)) buffer1
+    else buffer1.intersection(buffer2)
   }
+}
 
-  override def evaluate(buffer: Row): Any = {
-    buffer.getAs[Geometry](0)
-  }
-}
\ No newline at end of file
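With the buffer handling factored into TraitSTAggregateExec, a new geometry aggregate only has to supply reduce and merge. A hypothetical example (not part of this patch) sketching the pattern:

    import com.vividsolutions.jts.geom.Geometry
    import org.apache.spark.sql.expressions.Aggregator

    // Hypothetical aggregate returning the convex hull of all input geometries,
    // reusing the sentinel zero value and encoders supplied by the trait.
    class ST_ConvexHull_Aggr extends Aggregator[Geometry, Geometry, Geometry]
        with TraitSTAggregateExec {
      def reduce(buffer: Geometry, input: Geometry): Geometry =
        if (buffer.equalsExact(initialGeometry)) input.convexHull()
        else buffer.union(input).convexHull()    // JTS Geometry.convexHull()
      def merge(b1: Geometry, b2: Geometry): Geometry =
        if (b1.equalsExact(initialGeometry)) b2
        else if (b2.equalsExact(initialGeometry)) b1
        else b1.union(b2).convexHull()
    }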
master("local[*]").appName("geosparksqlScalaTest") .config("spark.sql.warehouse.dir", warehouseLocation) - .enableHiveSupport().getOrCreate() + .enableHiveSupport() + .getOrCreate() import sparkSession.implicits._ diff --git a/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala b/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala index f949a5ef34f..26d1337b932 100644 --- a/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala +++ b/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala @@ -27,6 +27,9 @@ package org.datasyslab.geosparksql import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.UserDefinedAggregator +import org.apache.spark.sql.functions class aggregateFunctionTestScala extends TestBaseScala { @@ -60,37 +63,37 @@ class aggregateFunctionTestScala extends TestBaseScala { var union = sparkSession.sql("select ST_Union_Aggr(polygondf.polygonshape) from polygondf") assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100) } - } - it("Passed ST_Intersection_aggr") { + it("Passed ST_Intersection_aggr") { - val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonInputLocation).toDF("polygon_wkt") - twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_wkt") - twoPolygonsAsWktDf.show() + val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonInputLocation).toDF("polygon_wkt") + twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_wkt") + twoPolygonsAsWktDf.show() - sparkSession - .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_wkt") - .createOrReplaceTempView("two_polygons") + sparkSession + .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_wkt") + .createOrReplaceTempView("two_polygons") - val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons") - intersectionDF.show(false) + val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons") + intersectionDF.show(false) - assertResult(0.0034700160226227607)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea) - } + assertResult(0.0034700160226227607)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea) + } - it("Passed ST_Intersection_aggr no intersection gives empty polygon") { + it("Passed ST_Intersection_aggr no intersection gives empty polygon") { - val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonNoIntersectionInputLocation).toDF("polygon_wkt") - twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_no_intersection_wkt") - twoPolygonsAsWktDf.show() + val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonNoIntersectionInputLocation).toDF("polygon_wkt") + twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_no_intersection_wkt") + twoPolygonsAsWktDf.show() - sparkSession - .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_no_intersection_wkt") - .createOrReplaceTempView("two_polygons_no_intersection") + sparkSession + .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_no_intersection_wkt") + .createOrReplaceTempView("two_polygons_no_intersection") - val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons_no_intersection") - intersectionDF.show(false) + val intersectionDF = sparkSession.sql("select 
diff --git a/sql/src/test/scala/org/datasyslab/geosparksql/TestBaseScala.scala b/sql/src/test/scala/org/datasyslab/geosparksql/TestBaseScala.scala
index 05804b09547..ac0bf887372 100644
--- a/sql/src/test/scala/org/datasyslab/geosparksql/TestBaseScala.scala
+++ b/sql/src/test/scala/org/datasyslab/geosparksql/TestBaseScala.scala
@@ -42,7 +42,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll{
     config("spark.kryo.registrator", classOf[GeoSparkKryoRegistrator].getName).
     master("local[*]").appName("geosparksqlScalaTest")
     .config("spark.sql.warehouse.dir", warehouseLocation)
-    .enableHiveSupport().getOrCreate()
+    .enableHiveSupport()
+    .getOrCreate()
 
   import sparkSession.implicits._
 
diff --git a/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala b/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala
index f949a5ef34f..26d1337b932 100644
--- a/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala
+++ b/sql/src/test/scala/org/datasyslab/geosparksql/aggregateFunctionTestScala.scala
@@ -27,6 +27,9 @@
 package org.datasyslab.geosparksql
 
 import com.vividsolutions.jts.geom.{Coordinate, Geometry, GeometryFactory}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.UserDefinedAggregator
+import org.apache.spark.sql.functions
 
 class aggregateFunctionTestScala extends TestBaseScala {
 
@@ -60,37 +63,37 @@ class aggregateFunctionTestScala extends TestBaseScala {
       var union = sparkSession.sql("select ST_Union_Aggr(polygondf.polygonshape) from polygondf")
       assert(union.take(1)(0).get(0).asInstanceOf[Geometry].getArea == 10100)
     }
-  }
 
-  it("Passed ST_Intersection_aggr") {
+    it("Passed ST_Intersection_aggr") {
 
-    val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonInputLocation).toDF("polygon_wkt")
-    twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_wkt")
-    twoPolygonsAsWktDf.show()
+      val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonInputLocation).toDF("polygon_wkt")
+      twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_wkt")
+      twoPolygonsAsWktDf.show()
 
-    sparkSession
-      .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_wkt")
-      .createOrReplaceTempView("two_polygons")
+      sparkSession
+        .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_wkt")
+        .createOrReplaceTempView("two_polygons")
 
-    val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons")
-    intersectionDF.show(false)
+      val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons")
+      intersectionDF.show(false)
 
-    assertResult(0.0034700160226227607)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
-  }
+      assertResult(0.0034700160226227607)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
+    }
 
-  it("Passed ST_Intersection_aggr no intersection gives empty polygon") {
+    it("Passed ST_Intersection_aggr no intersection gives empty polygon") {
 
-    val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonNoIntersectionInputLocation).toDF("polygon_wkt")
-    twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_no_intersection_wkt")
-    twoPolygonsAsWktDf.show()
+      val twoPolygonsAsWktDf = sparkSession.read.textFile(intersectionPolygonNoIntersectionInputLocation).toDF("polygon_wkt")
+      twoPolygonsAsWktDf.createOrReplaceTempView("two_polygons_no_intersection_wkt")
+      twoPolygonsAsWktDf.show()
 
-    sparkSession
-      .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_no_intersection_wkt")
-      .createOrReplaceTempView("two_polygons_no_intersection")
+      sparkSession
+        .sql("select ST_GeomFromWKT(polygon_wkt) as polygon from two_polygons_no_intersection_wkt")
+        .createOrReplaceTempView("two_polygons_no_intersection")
 
-    val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons_no_intersection")
-    intersectionDF.show(false)
+      val intersectionDF = sparkSession.sql("select ST_Intersection_Aggr(polygon) from two_polygons_no_intersection")
+      intersectionDF.show(false)
 
-    assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
+      assertResult(0.0)(intersectionDF.take(1)(0).get(0).asInstanceOf[Geometry].getArea)
+    }
   }
 }
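For completeness, the SQL assertions above can equally be written against the DataFrame API. A sketch reusing the polygondf temp view from the union test:

    import com.vividsolutions.jts.geom.Geometry
    import org.apache.spark.sql.functions.expr

    // Same aggregate as the SQL test, invoked through a DataFrame expression.
    val unionArea = sparkSession.table("polygondf")
      .select(expr("ST_Union_Aggr(polygonshape)"))
      .head().getAs[Geometry](0).getArea
    assert(unionArea == 10100)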