Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// xsbt clean unidoc previewSite
// xsbt clean unidoc ghpagesPushSite
// xsbt -Dsbt.global.base=/home/eje/.sbt/sonatype +publish
// xsbt +publish
// make sure sparkVersion and pythonVersion are set as you want them prior to +publish

import scala.sys.process._
Expand All @@ -9,9 +9,9 @@ name := "isarn-sketches-spark"

organization := "org.isarnproject"

val packageVersion = "0.3.1"
val packageVersion = "0.4.0-SNAPSHOT"

val sparkVersion = "2.2.2"
val sparkVersion = "2.4.0"

val pythonVersion = "2.7"

Expand All @@ -29,6 +29,12 @@ crossScalaVersions := Seq("2.11.12") // scala 2.12 when spark supports it

pomIncludeRepository := { _ => false }

//isSnapshot := true

//publishConfiguration := publishConfiguration.value.withOverwrite(true)

//publishLocalConfiguration := publishLocalConfiguration.value.withOverwrite(true)

publishMavenStyle := true

publishTo := {
Expand Down Expand Up @@ -60,7 +66,7 @@ developers := List(
)

libraryDependencies ++= Seq(
"org.isarnproject" %% "isarn-sketches" % "0.1.2",
"org.isarnproject" % "isarn-sketches-java" % "0.2.2-LOCAL",
"org.apache.spark" %% "spark-core" % sparkVersion % Provided,
"org.apache.spark" %% "spark-sql" % sparkVersion % Provided,
"org.apache.spark" %% "spark-mllib" % sparkVersion % Provided,
Expand All @@ -75,7 +81,7 @@ initialCommands in console := """
|import org.apache.spark.SparkContext._
|import org.apache.spark.rdd.RDD
|import org.apache.spark.ml.linalg.Vectors
|import org.isarnproject.sketches.TDigest
|import org.isarnproject.sketches.java.TDigest
|import org.isarnproject.sketches.udaf._
|import org.apache.spark.isarnproject.sketches.udt._
|val initialConf = new SparkConf().setAppName("repl").set("spark.serializer", "org.apache.spark.serializer.KryoSerializer").set("spark.kryoserializer.buffer", "16mb")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.tdmap.TDigestMap
import org.isarnproject.sketches.java.TDigest

import java.util.Arrays

/** A type for receiving the results of deserializing [[TDigestUDT]].
* The payload is the tdigest member field, holding a TDigest object.
Expand All @@ -29,7 +30,7 @@ import org.isarnproject.sketches.tdmap.TDigestMap
* @param tdigest The TDigest payload, which does the actual sketching.
*/
@SQLUserDefinedType(udt = classOf[TDigestUDT])
case class TDigestSQL(tdigest: TDigest)
case class TDigestSQL(tdigest: TDigest, nSer: Int, nDeSer: Int)

/** A UserDefinedType for serializing and deserializing [[TDigestSQL]] structures during UDAF
* aggregations.
Expand All @@ -55,39 +56,41 @@ class TDigestUDT extends UserDefinedType[TDigestSQL] {
def sqlType: DataType = StructType(
StructField("delta", DoubleType, false) ::
StructField("maxDiscrete", IntegerType, false) ::
StructField("nclusters", IntegerType, false) ::
StructField("clustX", ArrayType(DoubleType, false), false) ::
StructField("clustM", ArrayType(DoubleType, false), false) ::
StructField("nSer", IntegerType, false) ::
StructField("nDeSer", IntegerType, false) ::
Nil)

def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql.tdigest)

def deserialize(datum: Any): TDigestSQL = TDigestSQL(deserializeTD(datum))

private[sketches] def serializeTD(td: TDigest): InternalRow = {
val TDigest(delta, maxDiscrete, nclusters, clusters) = td
val row = new GenericInternalRow(5)
row.setDouble(0, delta)
row.setInt(1, maxDiscrete)
row.setInt(2, nclusters)
val clustX = clusters.keys.toArray
val clustM = clusters.values.toArray
row.update(3, UnsafeArrayData.fromPrimitiveArray(clustX))
row.update(4, UnsafeArrayData.fromPrimitiveArray(clustM))
def serialize(tdsql: TDigestSQL): Any = serializeTD(tdsql)

def deserialize(datum: Any): TDigestSQL = deserializeTD(datum)

private[sketches] def serializeTD(tdsql: TDigestSQL): InternalRow = {
val td = tdsql.tdigest
//println(s"mass= ${td.mass()}")
//if (td.mass() >= 5) throw new Exception("boo!")
val row = new GenericInternalRow(6)
row.setDouble(0, td.getCompression())
row.setInt(1, td.getMaxDiscrete())
val clustX = Arrays.copyOf(td.getCentUnsafe(), td.size())
val clustM = Arrays.copyOf(td.getMassUnsafe(), td.size())
row.update(2, UnsafeArrayData.fromPrimitiveArray(clustX))
row.update(3, UnsafeArrayData.fromPrimitiveArray(clustM))
row.setInt(4, tdsql.nSer + 1)
row.setInt(5, tdsql.nDeSer)
row
}

private[sketches] def deserializeTD(datum: Any): TDigest = datum match {
case row: InternalRow =>
require(row.numFields == 5, s"expected row length 5, got ${row.numFields}")
private[sketches] def deserializeTD(datum: Any): TDigestSQL = datum match {
case row: InternalRow if (row.numFields == 6) =>
val delta = row.getDouble(0)
val maxDiscrete = row.getInt(1)
val nclusters = row.getInt(2)
val clustX = row.getArray(3).toDoubleArray()
val clustM = row.getArray(4).toDoubleArray()
val clusters = clustX.zip(clustM)
.foldLeft(TDigestMap.empty) { case (td, e) => td + e }
TDigest(delta, maxDiscrete, nclusters, clusters)
val clustX = row.getArray(2).toDoubleArray()
val clustM = row.getArray(3).toDoubleArray()
val sz = clustX.length
val td = new TDigest(delta, maxDiscrete, Arrays.copyOf(clustX, sz), Arrays.copyOf(clustM, sz))
TDigestSQL(td, row.getInt(4), row.getInt(5) + 1)
case u => throw new Exception(s"failed to deserialize: $u")
}
}
Expand Down Expand Up @@ -140,11 +143,11 @@ class TDigestArrayUDT extends UserDefinedType[TDigestArraySQL] {
def serialize(tdasql: TDigestArraySQL): Any = {
val row = new GenericInternalRow(5)
val tda: Array[TDigest] = tdasql.tdigests
val delta = if (tda.isEmpty) 0.0 else tda.head.delta
val maxDiscrete = if (tda.isEmpty) 0 else tda.head.maxDiscrete
val clustS = tda.map(_.nclusters)
val clustX = tda.flatMap(_.clusters.keys)
val clustM = tda.flatMap(_.clusters.values)
val delta = if (tda.isEmpty) 0.0 else tda.head.getCompression()
val maxDiscrete = if (tda.isEmpty) 0 else tda.head.getMaxDiscrete()
val clustS = tda.map(_.size())
val clustX = tda.flatMap(_.getCentUnsafe())
val clustM = tda.flatMap(_.getMassUnsafe())
row.setDouble(0, delta)
row.setInt(1, maxDiscrete)
row.update(2, UnsafeArrayData.fromPrimitiveArray(clustS))
Expand All @@ -165,8 +168,7 @@ class TDigestArrayUDT extends UserDefinedType[TDigestArraySQL] {
val tda = clustS.map { nclusters =>
val x = clustX.slice(beg, beg + nclusters)
val m = clustM.slice(beg, beg + nclusters)
val clusters = x.zip(m).foldLeft(TDigestMap.empty) { case (td, e) => td + e }
val td = TDigest(delta, maxDiscrete, nclusters, clusters)
val td = new TDigest(delta, maxDiscrete, x, m)
beg += nclusters
td
}
Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/org/isarnproject/pipelines/TDigestFI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row

import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.java.TDigest
import org.apache.spark.isarnproject.sketches.udt._
import org.isarnproject.sketches.udaf._

Expand Down Expand Up @@ -62,7 +62,7 @@ package params {
*/
final val delta: DoubleParam =
new DoubleParam(this, "delta", "t-digest compression (> 0)", ParamValidators.gt(0.0))
setDefault(delta, org.isarnproject.sketches.TDigest.deltaDefault)
setDefault(delta, TDigest.COMPRESSION_DEFAULT)
final def getDelta: Double = $(delta)
final def setDelta(value: Double): this.type = set(delta, value)

Expand Down Expand Up @@ -280,13 +280,13 @@ class TDigestFI(override val uid: String) extends Estimator[TDigestFIModel] with
case v: MLSparse =>
var jBeg = 0
v.foreachActive((j, x) => {
for { k <- jBeg until j } { td(k) += 0.0 }
td(j) += x
for { k <- jBeg until j } { td(k).update(0.0) }
td(j).update(x)
jBeg = j + 1
})
for { k <- jBeg until v.size } { td(k) += 0.0 }
for { k <- jBeg until v.size } { td(k).update(0.0) }
case _ =>
for { j <- 0 until fv.size } { td(j) += fv(j) }
for { j <- 0 until fv.size } { td(j).update(fv(j)) }
}
td
},
Expand All @@ -298,7 +298,7 @@ class TDigestFI(override val uid: String) extends Estimator[TDigestFIModel] with
} else {
require(td1.length == td2.length, "mismatched t-digest arrays")
for { j <- 0 until td1.length } {
td1(j) ++= td2(j)
td1(j) = TDigest.merge(td1(j), td2(j))
}
td1
})
Expand Down
56 changes: 38 additions & 18 deletions src/main/scala/org/isarnproject/sketches/udaf/TDigestUDAF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,28 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package org.isarnproject.sketches.udaf
package org.apache.spark.sql.types {

import scala.math.Ordering
import scala.reflect.runtime.universe.typeTag

import org.isarnproject.sketches.java.TDigest
/*
class TDigestType private() extends AtomicType {
private[sql] type InternalType =
}
*/
}

package org.isarnproject.sketches.udaf {

import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row

import org.isarnproject.sketches.TDigest
import org.isarnproject.sketches.java.TDigest

import org.apache.spark.isarnproject.sketches.udt._

Expand Down Expand Up @@ -52,17 +65,22 @@ case class TDigestUDAF[N](deltaV: Double, maxDiscreteV: Int)(implicit
def dataType: DataType = TDigestUDT

def initialize(buf: MutableAggregationBuffer): Unit = {
buf(0) = TDigestSQL(TDigest.empty(deltaV, maxDiscreteV))
buf(0) = TDigestSQL(TDigest.empty(deltaV, maxDiscreteV), 0, 0)
}

def update(buf: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
buf(0) = TDigestSQL(buf.getAs[TDigestSQL](0).tdigest + num.toDouble(input.getAs[N](0)))
val tdsql = buf.getAs[TDigestSQL](0)
val td = tdsql.tdigest
td.update(num.toDouble(input.getAs[N](0)))
buf(0) = TDigestSQL(td, tdsql.nSer, tdsql.nDeSer)
}
}

def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
buf1(0) = TDigestSQL(buf1.getAs[TDigestSQL](0).tdigest ++ buf2.getAs[TDigestSQL](0).tdigest)
val tdsql1 = buf1.getAs[TDigestSQL](0)
val tdsql2 = buf2.getAs[TDigestSQL](0)
buf1(0) = TDigestSQL(TDigest.merge(tdsql1.tdigest, buf2.getAs[TDigestSQL](0).tdigest), tdsql1.nSer + tdsql2.nSer, tdsql1.nDeSer + tdsql2.nDeSer)
}

def evaluate(buf: Row): Any = buf.getAs[TDigestSQL](0)
Expand Down Expand Up @@ -92,7 +110,7 @@ abstract class TDigestMultiUDAF extends UserDefinedAggregateFunction {
Array.fill(tds2.length) { TDigest.empty(deltaV, maxDiscreteV) }
}
require(tds1.length == tds2.length)
for { j <- 0 until tds1.length } { tds1(j) ++= tds2(j) }
for { j <- 0 until tds1.length } { tds1(j) = TDigest.merge(tds1(j), tds2(j)) }
buf1(0) = TDigestArraySQL(tds1)
}
}
Expand Down Expand Up @@ -129,13 +147,13 @@ case class TDigestMLVecUDAF(deltaV: Double, maxDiscreteV: Int) extends TDigestMu
case v: org.apache.spark.ml.linalg.SparseVector =>
var jBeg = 0
v.foreachActive((j, x) => {
for { k <- jBeg until j } { tdigests(k) += 0.0 }
tdigests(j) += x
for { k <- jBeg until j } { tdigests(k).update(0.0) }
tdigests(j).update(x)
jBeg = j + 1
})
for { k <- jBeg until vec.size } { tdigests(k) += 0.0 }
for { k <- jBeg until vec.size } { tdigests(k).update(0.0) }
case _ =>
for { j <- 0 until vec.size } { tdigests(j) += vec(j) }
for { j <- 0 until vec.size } { tdigests(j).update(vec(j)) }
}
buf(0) = TDigestArraySQL(tdigests)
}
Expand Down Expand Up @@ -172,13 +190,13 @@ case class TDigestMLLibVecUDAF(deltaV: Double, maxDiscreteV: Int) extends TDiges
case v: org.apache.spark.mllib.linalg.SparseVector =>
var jBeg = 0
v.foreachActive((j, x) => {
for { k <- jBeg until j } { tdigests(k) += 0.0 }
tdigests(j) += x
for { k <- jBeg until j } { tdigests(k).update(0.0) }
tdigests(j).update(x)
jBeg = j + 1
})
for { k <- jBeg until vec.size } { tdigests(k) += 0.0 }
for { k <- jBeg until vec.size } { tdigests(k).update(0.0) }
case _ =>
for { j <- 0 until vec.size } { tdigests(j) += vec(j) }
for { j <- 0 until vec.size } { tdigests(j).update(vec(j)) }
}
buf(0) = TDigestArraySQL(tdigests)
}
Expand Down Expand Up @@ -215,7 +233,7 @@ case class TDigestArrayUDAF[N](deltaV: Double, maxDiscreteV: Int)(implicit
require(tdigests.length == data.length)
var j = 0
for { x <- data } {
if (x != null) tdigests(j) += num.toDouble(x)
if (x != null) tdigests(j).update(num.toDouble(x))
j += 1
}
buf(0) = TDigestArraySQL(tdigests)
Expand Down Expand Up @@ -247,14 +265,14 @@ case class TDigestReduceUDAF(deltaV: Double, maxDiscreteV: Int) extends
def dataType: DataType = TDigestUDT

def initialize(buf: MutableAggregationBuffer): Unit = {
buf(0) = TDigestSQL(TDigest.empty(deltaV, maxDiscreteV))
buf(0) = TDigestSQL(TDigest.empty(deltaV, maxDiscreteV), 0, 0)
}

def update(buf: MutableAggregationBuffer, input: Row): Unit = this.merge(buf, input)

def merge(buf1: MutableAggregationBuffer, buf2: Row): Unit = {
if (!buf2.isNullAt(0)) {
buf1(0) = TDigestSQL(buf1.getAs[TDigestSQL](0).tdigest ++ buf2.getAs[TDigestSQL](0).tdigest)
buf1(0) = TDigestSQL(TDigest.merge(buf1.getAs[TDigestSQL](0).tdigest, buf2.getAs[TDigestSQL](0).tdigest), 0, 0)
}
}

Expand Down Expand Up @@ -299,7 +317,7 @@ case class TDigestArrayReduceUDAF(deltaV: Double, maxDiscreteV: Int) extends
Array.fill(tds2.length) { TDigest.empty(deltaV, maxDiscreteV) }
}
require(tds1.length == tds2.length)
for { j <- 0 until tds1.length } { tds1(j) ++= tds2(j) }
for { j <- 0 until tds1.length } { tds1(j) = TDigest.merge(tds1(j), tds2(j)) }
buf1(0) = TDigestArraySQL(tds1)
}
}
Expand Down Expand Up @@ -345,3 +363,5 @@ object pythonBindings {
def tdigestArrayReduceUDAF(delta: Double, maxDiscrete: Int) =
TDigestArrayReduceUDAF(delta, maxDiscrete)
}

}
Loading