Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def groupByHash(tests: Seq[TestDefinition]): Seq[Tests.Group] = {
case (i, groupTests) =>
val options = ForkOptions()
.withRunJVMOptions(Vector("-Dspark.ui.enabled=false", "-Xmx1024m"))

Group(i.toString, groupTests, SubProcess(options))
}
.toSeq
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,17 @@ import org.yaml.snakeyaml.Yaml
import io.projectglow.SparkShim._
import io.projectglow.common.WithUtils
import io.projectglow.sql.expressions._
import io.projectglow.sql.optimizer.{ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule}
import io.projectglow.sql.optimizer.{ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule, ResolveGenotypeFields}

// TODO(hhd): Spark 3.0 allows extensions to register functions. After Spark 3.0 is released,
// we should move all extensions into this class.
class GlowSQLExtensions extends (SparkSessionExtensions => Unit) {
val resolutionRules: Seq[Rule[LogicalPlan]] =
Seq(ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule)
Seq(
ReplaceExpressionsRule,
ResolveAggregateFunctionsRule,
ResolveExpandStructRule,
ResolveGenotypeFields)
val optimizations: Seq[Rule[LogicalPlan]] = Seq()

def apply(extensions: SparkSessionExtensions): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.sql.{AnalysisException, SQLUtils}
import org.apache.spark.sql.SQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, Literal, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.{ExpectsGenotypeFields, Rewrite}
import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo, Rewrite}

case class SampleSummaryStatsState(var sampleId: String, var momentAggState: MomentAggState) {
def this() = this(null, null) // need 0-arg constructor for serialization
Expand All @@ -48,6 +48,7 @@ case class SampleSummaryStatsState(var sampleId: String, var momentAggState: Mom
case class PerSampleSummaryStatistics(
genotypes: Expression,
field: Expression,
genotypeInfo: Option[GenotypeInfo] = None,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[mutable.ArrayBuffer[SampleSummaryStatsState]]
Expand All @@ -66,7 +67,7 @@ case class PerSampleSummaryStatistics(
}

override def genotypesExpr: Expression = genotypes
override def genotypeFieldsRequired: Seq[StructField] = {
override def requiredGenotypeFields: Seq[StructField] = {
if (!field.foldable || field.dataType != StringType) {
throw SQLUtils.newAnalysisException("Field must be foldable string")
}
Expand All @@ -81,7 +82,11 @@ case class PerSampleSummaryStatistics(
}

override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField)
private lazy val hasSampleIds = optionalFieldIndices(0) != -1

override def withGenotypeInfo(genotypeInfo: GenotypeInfo): PerSampleSummaryStatistics = {
copy(genotypeInfo = Some(genotypeInfo))
}
private lazy val hasSampleIds = getGenotypeInfo.optionalFieldIndices(0) != -1

override def createAggregationBuffer(): ArrayBuffer[SampleSummaryStatsState] = {
mutable.ArrayBuffer[SampleSummaryStatsState]()
Expand All @@ -100,22 +105,22 @@ case class PerSampleSummaryStatistics(
}

private lazy val updateStateFn: (MomentAggState, InternalRow) => Unit = {
genotypeFieldsRequired.head.dataType match {
requiredGenotypeFields.head.dataType match {
case FloatType =>
(state, genotype) => {
state.update(genotype.getFloat(genotypeFieldIndices(0)))
state.update(genotype.getFloat(getGenotypeInfo.requiredFieldIndices(0)))
}
case DoubleType =>
(state, genotype) => {
state.update(genotype.getDouble(genotypeFieldIndices(0)))
state.update(genotype.getDouble(getGenotypeInfo.requiredFieldIndices(0)))
}
case IntegerType =>
(state, genotype) => {
state.update(genotype.getInt(genotypeFieldIndices(0)))
state.update(genotype.getInt(getGenotypeInfo.requiredFieldIndices(0)))
}
case LongType =>
(state, genotype) => {
state.update(genotype.getLong(genotypeFieldIndices(0)))
state.update(genotype.getLong(getGenotypeInfo.requiredFieldIndices(0)))
}
}
}
Expand All @@ -132,17 +137,17 @@ case class PerSampleSummaryStatistics(
if (i >= buffer.size) {
val sampleId = if (hasSampleIds) {
genotypesArray
.getStruct(buffer.size, genotypeStructSize)
.getString(optionalFieldIndices(0))
.getStruct(buffer.size, getGenotypeInfo.size)
.getString(getGenotypeInfo.optionalFieldIndices(0))
} else {
null
}
buffer.append(SampleSummaryStatsState(sampleId, MomentAggState()))
}

val struct = genotypesArray.getStruct(i, genotypeStructSize)
if (!struct.isNullAt(genotypeFieldIndices(0))) {
updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, genotypeStructSize))
val struct = genotypesArray.getStruct(i, getGenotypeInfo.size)
if (!struct.isNullAt(getGenotypeInfo.requiredFieldIndices(0))) {
updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, getGenotypeInfo.size))
}
i += 1
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.ExpectsGenotypeFields
import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo}
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
Expand All @@ -43,24 +43,29 @@ case class CallSummaryStats(
genotypes: Expression,
refAllele: Expression,
altAlleles: Expression,
genotypeInfo: Option[GenotypeInfo],
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends TypedImperativeAggregate[mutable.ArrayBuffer[SampleCallStats]]
with ExpectsGenotypeFields
with GlowLogging {

def this(genotypes: Expression, refAllele: Expression, altAlleles: Expression) = {
this(genotypes, refAllele, altAlleles, 0, 0)
this(genotypes, refAllele, altAlleles, None, 0, 0)
}

override def genotypesExpr: Expression = genotypes

override def genotypeFieldsRequired: Seq[StructField] = {
override def requiredGenotypeFields: Seq[StructField] = {
Seq(VariantSchemas.callsField)
}

override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField)
private lazy val hasSampleIds = optionalFieldIndices(0) != -1

override def withGenotypeInfo(genotypeInfo: GenotypeInfo): CallSummaryStats = {
copy(genotypeInfo = Some(genotypeInfo))
}
private lazy val hasSampleIds = getGenotypeInfo.optionalFieldIndices(0) != -1
override def children: Seq[Expression] = Seq(genotypes, refAllele, altAlleles)
override def nullable: Boolean = false

Expand Down Expand Up @@ -116,12 +121,12 @@ case class CallSummaryStats(
var j = 0
while (j < genotypesArr.numElements()) {
val struct = genotypesArr
.getStruct(j, genotypeStructSize)
.getStruct(j, getGenotypeInfo.size)

// Make sure the buffer has an entry for this sample
if (j >= buffer.size) {
val sampleId = if (hasSampleIds) {
struct.getUTF8String(optionalFieldIndices(0))
struct.getUTF8String(getGenotypeInfo.optionalFieldIndices(0))
} else {
null
}
Expand All @@ -130,7 +135,7 @@ case class CallSummaryStats(

val stats = buffer(j)
val calls = struct
.getStruct(genotypeFieldIndices.head, 2)
.getStruct(getGenotypeInfo.requiredFieldIndices.head, 2)
.getArray(0)
var k = 0
var isUncalled = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
import org.apache.spark.sql.types._

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.{ExpectsGenotypeFields, LeveneHaldane, Rewrite}
import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo, LeveneHaldane, Rewrite}

/**
* Contains implementations of QC functions. These implementations are called during both
Expand Down Expand Up @@ -231,7 +231,11 @@ object VariantQcExprs extends GlowLogging {
}
}

case class HardyWeinberg(genotypes: Expression) extends UnaryExpression with ExpectsGenotypeFields {
case class HardyWeinberg(genotypes: Expression, genotypeInfo: Option[GenotypeInfo])
extends UnaryExpression
with ExpectsGenotypeFields {
def this(genotypes: Expression) = this(genotypes, None)

override def dataType: DataType =
StructType(
Seq(
Expand All @@ -242,24 +246,34 @@ case class HardyWeinberg(genotypes: Expression) extends UnaryExpression with Exp

override def genotypesExpr: Expression = genotypes

override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.callsField)
override def requiredGenotypeFields: Seq[StructField] = Seq(VariantSchemas.callsField)

override def withGenotypeInfo(genotypeInfo: GenotypeInfo): HardyWeinberg = {
copy(genotypeInfo = Some(genotypeInfo))
}

override def child: Expression = genotypes

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val fn = "io.projectglow.sql.expressions.VariantQcExprs.hardyWeinberg"
nullSafeCodeGen(ctx, ev, calls => {
s"""
|${ev.value} = $fn($calls, $genotypeStructSize, ${genotypeFieldIndices.head});
nullSafeCodeGen(
ctx,
ev,
calls => {
s"""
|${ev.value} = $fn($calls, ${getGenotypeInfo.size}, ${getGenotypeInfo
.requiredFieldIndices
.head});
""".stripMargin
})
}
)
}

override def nullSafeEval(input: Any): Any = {
VariantQcExprs.hardyWeinberg(
input.asInstanceOf[ArrayData],
genotypeStructSize,
genotypeFieldIndices.head
getGenotypeInfo.size,
getGenotypeInfo.requiredFieldIndices.head
)
}
}
Expand All @@ -270,29 +284,43 @@ object HardyWeinberg {

case class HardyWeinbergStruct(hetFreqHwe: Double, pValueHwe: Double)

case class CallStats(genotypes: Expression) extends UnaryExpression with ExpectsGenotypeFields {
case class CallStats(genotypes: Expression, genotypeInfo: Option[GenotypeInfo])
extends UnaryExpression
with ExpectsGenotypeFields {
def this(genotypes: Expression) = this(genotypes, None)

lazy val dataType: DataType = CallStats.schema

override def genotypesExpr: Expression = genotypes

override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.callsField)
override def requiredGenotypeFields: Seq[StructField] = Seq(VariantSchemas.callsField)

override def withGenotypeInfo(genotypeInfo: GenotypeInfo): CallStats = {
copy(genotypeInfo = Some(genotypeInfo))
}

override def child: Expression = genotypes

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val fn = "io.projectglow.sql.expressions.VariantQcExprs.callStats"
nullSafeCodeGen(ctx, ev, calls => {
s"""
|${ev.value} = $fn($calls, $genotypeStructSize, ${genotypeFieldIndices.head});
nullSafeCodeGen(
ctx,
ev,
calls => {
s"""
|${ev.value} = $fn($calls, ${getGenotypeInfo.size}, ${getGenotypeInfo
.requiredFieldIndices
.head});
""".stripMargin
})
}
)
}

override def nullSafeEval(input: Any): Any = {
VariantQcExprs.callStats(
input.asInstanceOf[ArrayData],
genotypeStructSize,
genotypeFieldIndices.head
getGenotypeInfo.size,
getGenotypeInfo.requiredFieldIndices.head
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.common.VariantSchemas
import io.projectglow.sql.util.ExpectsGenotypeFields
import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo}

/**
* Implementations of utility functions for transforming variant representations. These
Expand Down Expand Up @@ -164,32 +164,44 @@ object VariantType {
* of the calls array for the sample at that position if no calls are missing, or -1 if any calls
* are missing.
*/
case class GenotypeStates(genotypes: Expression)
case class GenotypeStates(genotypes: Expression, genotypeInfo: Option[GenotypeInfo])
extends UnaryExpression
with ExpectsGenotypeFields {

def this(genotypes: Expression) = this(genotypes, None)

override def genotypesExpr: Expression = genotypes

override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.callsField)
override def requiredGenotypeFields: Seq[StructField] = Seq(VariantSchemas.callsField)

override def withGenotypeInfo(genotypeInfo: GenotypeInfo): GenotypeStates = {
copy(genotypes, Some(genotypeInfo))
}

override def child: Expression = genotypes

override def dataType: DataType = ArrayType(IntegerType)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val fn = "io.projectglow.sql.expressions.VariantUtilExprs.genotypeStates"
nullSafeCodeGen(ctx, ev, calls => {
s"""
|${ev.value} = $fn($calls, $genotypeStructSize, ${genotypeFieldIndices.head});
nullSafeCodeGen(
ctx,
ev,
calls => {
s"""
|${ev.value} = $fn($calls, ${getGenotypeInfo.size}, ${getGenotypeInfo
.requiredFieldIndices
.head});
""".stripMargin
})
}
)
}

override def nullSafeEval(input: Any): Any = {
VariantUtilExprs.genotypeStates(
input.asInstanceOf[ArrayData],
genotypeStructSize,
genotypeFieldIndices.head
getGenotypeInfo.size,
getGenotypeInfo.requiredFieldIndices.head
)
}
}
Expand Down
Loading