@@ -60,15 +60,19 @@ case class MomentAggState(
def update(element: Int): Unit = update(element.toDouble)
def update(element: Float): Unit = update(element.toDouble)

/**
* Writes the mean, stdev, min, and max into the input row beginning at the provided offset.
*/
def toInternalRow(row: InternalRow, offset: Int = 0): InternalRow = {
row.update(offset, if (count > 0) mean else null)
row.update(offset + 1, if (count > 0) Math.sqrt(m2 / (count - 1)) else null)
row.update(offset + 2, if (count > 0) min else null)
row.update(offset + 3, if (count > 0) max else null)
row
}

def toInternalRow: InternalRow = {
new GenericInternalRow(
Array(
if (count > 0) mean else null,
if (count > 0) Math.sqrt(m2 / (count - 1)) else null,
if (count > 0) min else null,
if (count > 0) max else null
)
)
toInternalRow(new GenericInternalRow(4))
}
}
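A minimal usage sketch of the offset-based toInternalRow introduced above, assuming the caller sizes the row for one leading sample-id column plus the four statistics; the sample id value and inputs below are invented for illustration:

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.unsafe.types.UTF8String

// Accumulate a couple of values, then pack the statistics after a sample-id slot.
val state = MomentAggState()
state.update(0.5)
state.update(1.5)
val row = new GenericInternalRow(5)              // 1 sample-id slot + 4 statistic slots
row.update(0, UTF8String.fromString("HG00096"))  // hypothetical sample id
state.toInternalRow(row, offset = 1)             // mean, stdev, min, max land in slots 1-4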

@@ -23,12 +23,11 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.ExpectsGenotypeFields

@@ -41,7 +40,8 @@ case class SampleSummaryStatsState(var sampleId: String, var momentAggState: Mom
* sample in a cohort. The field is determined by the provided [[StructField]]. If the field does
* not exist in the genotype struct, an analysis error will be thrown.
*
* The return type is a map of sampleId -> summary statistics.
* The return type is an array of summary statistics. If sample ids are included in the input,
* they'll be propagated to the results.
*/
case class PerSampleSummaryStatistics(
genotypes: Expression,
@@ -54,38 +54,53 @@ case class PerSampleSummaryStatistics(

override def children: Seq[Expression] = Seq(genotypes)
override def nullable: Boolean = false
override def dataType: DataType = MapType(StringType, MomentAggState.schema)

override def dataType: DataType =
if (hasSampleIds) {
val fields = VariantSchemas.sampleIdField +: MomentAggState.schema.fields
ArrayType(StructType(fields))
} else {
ArrayType(MomentAggState.schema)
}

override def genotypesExpr: Expression = genotypes
override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.sampleIdField, field)
override def genotypeFieldsRequired: Seq[StructField] = Seq(field)
override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField)
private lazy val hasSampleIds = optionalFieldIndices(0) != -1

override def createAggregationBuffer(): ArrayBuffer[SampleSummaryStatsState] = {
mutable.ArrayBuffer[SampleSummaryStatsState]()
}

override def eval(buffer: ArrayBuffer[SampleSummaryStatsState]): Any = {
val keys = new GenericArrayData(buffer.map(s => UTF8String.fromString(s.sampleId)))
val values = new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow))
new ArrayBasedMapData(keys, values)
if (!hasSampleIds) {
new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow))
} else {
new GenericArrayData(buffer.map { s =>
val outputRow = new GenericInternalRow(MomentAggState.schema.length + 1)
outputRow.update(0, UTF8String.fromString(s.sampleId))
s.momentAggState.toInternalRow(outputRow, offset = 1)
})
}
}

private lazy val updateStateFn: (MomentAggState, InternalRow) => Unit = {
field.dataType match {
case FloatType =>
(state, genotype) => {
state.update(genotype.getFloat(genotypeFieldIndices(1)))
state.update(genotype.getFloat(genotypeFieldIndices(0)))
}
case DoubleType =>
(state, genotype) => {
state.update(genotype.getDouble(genotypeFieldIndices(1)))
state.update(genotype.getDouble(genotypeFieldIndices(0)))
}
case IntegerType =>
(state, genotype) => {
state.update(genotype.getInt(genotypeFieldIndices(1)))
state.update(genotype.getInt(genotypeFieldIndices(0)))
}
case LongType =>
(state, genotype) => {
state.update(genotype.getLong(genotypeFieldIndices(1)))
state.update(genotype.getLong(genotypeFieldIndices(0)))
}
}
}
@@ -100,14 +115,18 @@ case class PerSampleSummaryStatistics(

// Make sure the buffer has an entry for this sample
if (i >= buffer.size) {
val sampleId = genotypesArray
.getStruct(buffer.size, genotypeStructSize)
.getString(genotypeFieldIndices.head)
val sampleId = if (hasSampleIds) {
genotypesArray
.getStruct(buffer.size, genotypeStructSize)
.getString(optionalFieldIndices(0))
} else {
null
}
buffer.append(SampleSummaryStatsState(sampleId, MomentAggState()))
}

val struct = genotypesArray.getStruct(i, genotypeStructSize)
if (!struct.isNullAt(genotypeFieldIndices(1))) {
if (!struct.isNullAt(genotypeFieldIndices(0))) {
updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, genotypeStructSize))
}
i += 1
@@ -130,7 +149,9 @@ case class PerSampleSummaryStatistics(
)
var i = 0
while (i < buffer.size) {
require(buffer(i).sampleId == input(i).sampleId, s"Samples did not match at position $i")
require(
buffer(i).sampleId == input(i).sampleId,
s"Samples did not match at position $i (${buffer(i).sampleId}, ${input(i).sampleId})")
buffer(i).momentAggState =
MomentAggState.merge(buffer(i).momentAggState, input(i).momentAggState)
i += 1
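For reference, a sketch of the aggregate's result type under the two branches of dataType above, assuming MomentAggState.schema carries the four statistics written by toInternalRow; the field names below are placeholders, not taken from this diff:

import org.apache.spark.sql.types._

// Placeholder stand-in for MomentAggState.schema.
val momentSchema = StructType(Seq(
  StructField("mean", DoubleType),
  StructField("stdDev", DoubleType),
  StructField("min", DoubleType),
  StructField("max", DoubleType)))

// With sample ids present in the genotype struct:
// array<struct<sampleId: string, mean, stdDev, min, max>>
val withSampleIds = ArrayType(StructType(StructField("sampleId", StringType) +: momentSchema.fields))

// Without sample ids: array<struct<mean, stdDev, min, max>>
val withoutSampleIds = ArrayType(momentSchema)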
@@ -21,23 +21,23 @@ import java.nio.ByteBuffer
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.ExpectsGenotypeFields
import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

import io.projectglow.common.{GlowLogging, VariantSchemas}
import io.projectglow.sql.util.ExpectsGenotypeFields

/**
* Computes summary statistics per-sample in a genomic cohort. These statistics include the call
* rate and the number of different types of variants.
*
* The return type is a map of sampleId -> summary statistics.
* The return type is an array of summary statistics. If sample ids are included in the input
* schema, they'll be propagated to the results.
*/
case class CallSummaryStats(
genotypes: Expression,
@@ -52,12 +52,16 @@ case class CallSummaryStats(
override def genotypesExpr: Expression = genotypes

override def genotypeFieldsRequired: Seq[StructField] = {
Seq(VariantSchemas.callsField, VariantSchemas.sampleIdField)
Seq(VariantSchemas.callsField)
}

override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField)
private lazy val hasSampleIds = optionalFieldIndices(0) != -1
override def children: Seq[Expression] = Seq(genotypes, refAllele, altAlleles)
override def nullable: Boolean = false

override def dataType: DataType = MapType(StringType, SampleCallStats.outputSchema)
override def dataType: DataType =
ArrayType(SampleCallStats.outputSchema(hasSampleIds))

override def checkInputDataTypes(): TypeCheckResult = {
if (super.checkInputDataTypes().isFailure) {
@@ -86,9 +90,7 @@ case class CallSummaryStats(
}

override def eval(buffer: ArrayBuffer[SampleCallStats]): Any = {
val keys = new GenericArrayData(buffer.map(s => UTF8String.fromString(s.sampleId)))
val values = new GenericArrayData(buffer.map(_.toInternalRow))
new ArrayBasedMapData(keys, values)
new GenericArrayData(buffer.map(_.toInternalRow(hasSampleIds)))
}

override def update(
@@ -114,9 +116,12 @@

// Make sure the buffer has an entry for this sample
if (j >= buffer.size) {
val sampleId = struct
.getUTF8String(genotypeFieldIndices(1))
buffer.append(SampleCallStats(sampleId.toString))
val sampleId = if (hasSampleIds) {
struct.getUTF8String(optionalFieldIndices(0))
} else {
null
}
buffer.append(SampleCallStats(if (sampleId != null) sampleId.toString else null))
}

val stats = buffer(j)
@@ -212,7 +217,7 @@ }
}

case class SampleCallStats(
var sampleId: String,
var sampleId: String = null,
var nCalled: Long = 0,
var nUncalled: Long = 0,
var nHomRef: Long = 0,
@@ -224,12 +229,8 @@ case class SampleCallStats(
var nTransition: Long = 0,
var nSpanningDeletion: Long = 0) {

def this() = {
this(null)
}

def toInternalRow: InternalRow =
new GenericInternalRow(
def toInternalRow(includeSampleId: Boolean): InternalRow = {
val valueArr =
Array[Any](
nCalled.toDouble / (nCalled + nUncalled),
nCalled,
Expand All @@ -247,7 +248,13 @@ case class SampleCallStats(
nInsertion.toDouble / nDeletion,
nHet.toDouble / nHomVar
)
)

if (includeSampleId) {
new GenericInternalRow(Array[Any](UTF8String.fromString(sampleId)) ++ valueArr)
} else {
new GenericInternalRow(valueArr)
}
}
}
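A hedged usage sketch of the new includeSampleId flag: the flag passed to toInternalRow has to match the one used for outputSchema so that column positions line up. The values are invented, and outputSchema is package-private, so a snippet like this would have to live under io.projectglow:

val stats = SampleCallStats(sampleId = "NA12878", nCalled = 10, nUncalled = 2)
val hasSampleIds = true
val row = stats.toInternalRow(includeSampleId = hasSampleIds)
val schema = SampleCallStats.outputSchema(includeSampleId = hasSampleIds)
// Row and schema widths should agree; with sampleId included both gain one leading column.
assert(row.numFields == schema.length)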

object SampleCallStats extends GlowLogging {
@@ -267,7 +274,8 @@ object SampleCallStats extends GlowLogging {
out
}

val outputSchema = StructType(
private[projectglow] def outputSchema(includeSampleId: Boolean): StructType = StructType(
(if (includeSampleId) Some(VariantSchemas.sampleIdField) else None).toSeq ++
Seq(
StructField("callRate", DoubleType),
StructField("nCalled", LongType),
@@ -30,23 +30,27 @@ trait ExpectsGenotypeFields extends Expression {

protected def genotypeFieldsRequired: Seq[StructField]

protected def optionalGenotypeFields: Seq[StructField] = Seq.empty

private lazy val gStruct = genotypesExpr
.dataType
.asInstanceOf[ArrayType]
.elementType
.asInstanceOf[StructType]

protected lazy val genotypeFieldIndices: Seq[Int] = {
val gStruct = genotypesExpr
.dataType
.asInstanceOf[ArrayType]
.elementType
.asInstanceOf[StructType]
genotypeFieldsRequired.map { f =>
gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _))
}
}

protected lazy val optionalFieldIndices: Seq[Int] = {
optionalGenotypeFields.map { f =>
gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _))
}
}

protected lazy val genotypeStructSize: Int = {
val gStruct = genotypesExpr
.dataType
.asInstanceOf[ArrayType]
.elementType
.asInstanceOf[StructType]
gStruct.length
}
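An illustrative, stand-alone sketch of the required-versus-optional field lookup this trait now performs; SQLUtils.structFieldsEqualExceptNullability is approximated here with a plain name-and-type comparison, and the field names are assumptions rather than values from this diff:

import org.apache.spark.sql.types._

// A genotype struct without a sampleId field.
val genotypeStruct = StructType(Seq(
  StructField("calls", ArrayType(IntegerType)),
  StructField("depth", IntegerType)))

def indexOf(field: StructField, struct: StructType): Int =
  struct.indexWhere(f => f.name == field.name && f.dataType == field.dataType)

// Required fields are expected to resolve; optional fields may come back as -1.
val requiredIndices = Seq(StructField("calls", ArrayType(IntegerType))).map(indexOf(_, genotypeStruct))
val optionalIndices = Seq(StructField("sampleId", StringType)).map(indexOf(_, genotypeStruct))
val hasSampleIds = optionalIndices(0) != -1 // false here, so an aggregate would emit rows without ids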

@@ -178,8 +178,8 @@ private[projectglow] object VariantNormalizer extends GlowLogging {
* normalizes a single VariantContext by checking some conditions and then calling realignAlleles
*
* @param vc
* @param refGenomePathString
* @return: normalized VariantContext
* @param refGenomeDataSource
* @return normalized VariantContext
*/
private def normalizeVC(
vc: VariantContext,