diff --git a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala
index dc169f3ec..04978a273 100644
--- a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala
+++ b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala
@@ -60,15 +60,20 @@ case class MomentAggState(
   def update(element: Int): Unit = update(element.toDouble)
   def update(element: Float): Unit = update(element.toDouble)
 
+  /**
+   * Writes the mean, standard deviation, min, and max into the input row beginning at the
+   * provided offset, and returns the row.
+   */
+  def toInternalRow(row: InternalRow, offset: Int = 0): InternalRow = {
+    row.update(offset, if (count > 0) mean else null)
+    row.update(offset + 1, if (count > 0) Math.sqrt(m2 / (count - 1)) else null)
+    row.update(offset + 2, if (count > 0) min else null)
+    row.update(offset + 3, if (count > 0) max else null)
+    row
+  }
+
   def toInternalRow: InternalRow = {
-    new GenericInternalRow(
-      Array(
-        if (count > 0) mean else null,
-        if (count > 0) Math.sqrt(m2 / (count - 1)) else null,
-        if (count > 0) min else null,
-        if (count > 0) max else null
-      )
-    )
+    toInternalRow(new GenericInternalRow(4))
   }
 }
diff --git a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala
index 9bcfdda63..56d134873 100644
--- a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala
+++ b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala
@@ -23,12 +23,11 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-
 import io.projectglow.common.{GlowLogging, VariantSchemas}
 import io.projectglow.sql.util.ExpectsGenotypeFields
 
@@ -41,7 +40,10 @@ case class SampleSummaryStatsState(var sampleId: String, var momentAggState: Mom
  * sample in a cohort. The field is determined by the provided [[StructField]]. If the field does
  * not exist in the genotype struct, an analysis error will be thrown.
  *
- * The return type is a map of sampleId -> summary statistics.
+ * The return type is an array of summary statistics. If sample ids are included in the input,
+ * they'll be propagated to the results.
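+ * For example, when sample ids are present, each element of the returned array is a struct
+ * with sampleId followed by the summary statistics fields (mean, stdDev, min, max).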
  */
 case class PerSampleSummaryStatistics(
     genotypes: Expression,
@@ -54,38 +54,53 @@
   override def children: Seq[Expression] = Seq(genotypes)
   override def nullable: Boolean = false
-  override def dataType: DataType = MapType(StringType, MomentAggState.schema)
+
+  override def dataType: DataType =
+    if (hasSampleIds) {
+      val fields = VariantSchemas.sampleIdField +: MomentAggState.schema.fields
+      ArrayType(StructType(fields))
+    } else {
+      ArrayType(MomentAggState.schema)
+    }
 
   override def genotypesExpr: Expression = genotypes
 
-  override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.sampleIdField, field)
+  override def genotypeFieldsRequired: Seq[StructField] = Seq(field)
+
+  override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField)
+
+  private lazy val hasSampleIds = optionalFieldIndices(0) != -1
 
   override def createAggregationBuffer(): ArrayBuffer[SampleSummaryStatsState] = {
     mutable.ArrayBuffer[SampleSummaryStatsState]()
   }
 
   override def eval(buffer: ArrayBuffer[SampleSummaryStatsState]): Any = {
-    val keys = new GenericArrayData(buffer.map(s => UTF8String.fromString(s.sampleId)))
-    val values = new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow))
-    new ArrayBasedMapData(keys, values)
+    if (!hasSampleIds) {
+      new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow))
+    } else {
+      new GenericArrayData(buffer.map { s =>
+        val outputRow = new GenericInternalRow(MomentAggState.schema.length + 1)
+        outputRow.update(0, UTF8String.fromString(s.sampleId))
+        s.momentAggState.toInternalRow(outputRow, offset = 1)
+      })
+    }
   }
 
   private lazy val updateStateFn: (MomentAggState, InternalRow) => Unit = {
     field.dataType match {
       case FloatType =>
         (state, genotype) => {
-          state.update(genotype.getFloat(genotypeFieldIndices(1)))
+          state.update(genotype.getFloat(genotypeFieldIndices(0)))
         }
       case DoubleType =>
         (state, genotype) => {
-          state.update(genotype.getDouble(genotypeFieldIndices(1)))
+          state.update(genotype.getDouble(genotypeFieldIndices(0)))
         }
       case IntegerType =>
         (state, genotype) => {
-          state.update(genotype.getInt(genotypeFieldIndices(1)))
+          state.update(genotype.getInt(genotypeFieldIndices(0)))
         }
       case LongType =>
         (state, genotype) => {
-          state.update(genotype.getLong(genotypeFieldIndices(1)))
+          state.update(genotype.getLong(genotypeFieldIndices(0)))
         }
     }
   }
@@ -100,14 +115,18 @@
 
     // Make sure the buffer has an entry for this sample
     if (i >= buffer.size) {
-      val sampleId = genotypesArray
-        .getStruct(buffer.size, genotypeStructSize)
-        .getString(genotypeFieldIndices.head)
+      val sampleId = if (hasSampleIds) {
+        genotypesArray
+          .getStruct(buffer.size, genotypeStructSize)
+          .getString(optionalFieldIndices(0))
+      } else {
+        null
+      }
       buffer.append(SampleSummaryStatsState(sampleId, MomentAggState()))
     }
 
     val struct = genotypesArray.getStruct(i, genotypeStructSize)
-    if (!struct.isNullAt(genotypeFieldIndices(1))) {
+    if (!struct.isNullAt(genotypeFieldIndices(0))) {
       updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, genotypeStructSize))
     }
     i += 1
@@ -130,7 +149,9 @@
     )
     var i = 0
     while (i < buffer.size) {
-      require(buffer(i).sampleId == input(i).sampleId, s"Samples did not match at position $i")
+      require(
+        buffer(i).sampleId == input(i).sampleId,
+        s"Samples did not match at position $i (${buffer(i).sampleId}, ${input(i).sampleId})")
       buffer(i).momentAggState = MomentAggState.merge(buffer(i).momentAggState,
         input(i).momentAggState)
       i += 1
diff --git a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala
index 457b81fbb..c52a83113 100644
--- a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala
+++ b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala
@@ -21,23 +21,23 @@ import java.nio.ByteBuffer
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
+import io.projectglow.common.{GlowLogging, VariantSchemas}
+import io.projectglow.sql.util.ExpectsGenotypeFields
 import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-import io.projectglow.common.{GlowLogging, VariantSchemas}
-import io.projectglow.sql.util.ExpectsGenotypeFields
-
 /**
  * Computes summary statistics per-sample in a genomic cohort. These statistics include the call
  * rate and the number of different types of variants.
 *
- * The return type is a map of sampleId -> summary statistics.
+ * The return type is an array of summary statistics. If sample ids are included in the input
+ * schema, they'll be propagated to the results.
  */
 case class CallSummaryStats(
     genotypes: Expression,
@@ -52,12 +52,16 @@
   override def genotypesExpr: Expression = genotypes
   override def genotypeFieldsRequired: Seq[StructField] = {
-    Seq(VariantSchemas.callsField, VariantSchemas.sampleIdField)
+    Seq(VariantSchemas.callsField)
   }
+
+  override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField)
+
+  private lazy val hasSampleIds = optionalFieldIndices(0) != -1
 
   override def children: Seq[Expression] = Seq(genotypes, refAllele, altAlleles)
   override def nullable: Boolean = false
-  override def dataType: DataType = MapType(StringType, SampleCallStats.outputSchema)
+  override def dataType: DataType =
+    ArrayType(SampleCallStats.outputSchema(hasSampleIds))
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (super.checkInputDataTypes().isFailure) {
@@ -86,9 +90,7 @@
   }
 
   override def eval(buffer: ArrayBuffer[SampleCallStats]): Any = {
-    val keys = new GenericArrayData(buffer.map(s => UTF8String.fromString(s.sampleId)))
-    val values = new GenericArrayData(buffer.map(_.toInternalRow))
-    new ArrayBasedMapData(keys, values)
+    new GenericArrayData(buffer.map(_.toInternalRow(hasSampleIds)))
   }
 
   override def update(
@@ -114,9 +116,12 @@
 
       // Make sure the buffer has an entry for this sample
       if (j >= buffer.size) {
-        val sampleId = struct
-          .getUTF8String(genotypeFieldIndices(1))
-        buffer.append(SampleCallStats(sampleId.toString))
+        val sampleId = if (hasSampleIds) {
+          struct.getUTF8String(optionalFieldIndices(0))
+        } else {
+          null
+        }
+        buffer.append(SampleCallStats(if (sampleId != null) sampleId.toString else null))
       }
 
       val stats = buffer(j)
@@ -212,7 +217,7 @@
 }
 
 case class SampleCallStats(
-    var sampleId: String,
+    var sampleId: String = null,
     var nCalled: Long = 0,
     var nUncalled: Long = 0,
     var nHomRef: Long = 0,
@@ -224,12 +229,8 @@
     var nTransition: Long = 0,
     var nSpanningDeletion: Long = 0) {
 
-  def this() = {
-    this(null)
-  }
-
-  def toInternalRow: InternalRow =
-    new GenericInternalRow(
+  def toInternalRow(includeSampleId: Boolean): InternalRow = {
+    val valueArr = Array[Any](
       nCalled.toDouble / (nCalled + nUncalled),
       nCalled,
@@ -247,7 +248,13 @@
       nInsertion.toDouble / nDeletion,
       nHet.toDouble / nHomVar
     )
-  )
+
+    if (includeSampleId) {
+      new GenericInternalRow(Array[Any](UTF8String.fromString(sampleId)) ++ valueArr)
+    } else {
+      new GenericInternalRow(valueArr)
+    }
+  }
 }
 
 object SampleCallStats extends GlowLogging {
@@ -267,7 +274,8 @@
     out
   }
 
-  val outputSchema = StructType(
+  private[projectglow] def outputSchema(includeSampleId: Boolean): StructType = StructType(
+    (if (includeSampleId) Some(VariantSchemas.sampleIdField) else None).toSeq ++
     Seq(
       StructField("callRate", DoubleType),
       StructField("nCalled", LongType),
diff --git a/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala b/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala
index 5db85c49a..f78397aba 100644
--- a/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala
+++ b/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala
@@ -30,23 +30,27 @@
 
   protected def genotypeFieldsRequired: Seq[StructField]
 
+  protected def optionalGenotypeFields: Seq[StructField] = Seq.empty
+
+  private lazy val gStruct = genotypesExpr
+    .dataType
+    .asInstanceOf[ArrayType]
+    .elementType
+    .asInstanceOf[StructType]
+
   protected lazy val genotypeFieldIndices: Seq[Int] = {
-    val gStruct = genotypesExpr
-      .dataType
-      .asInstanceOf[ArrayType]
-      .elementType
-      .asInstanceOf[StructType]
     genotypeFieldsRequired.map { f =>
       gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _))
     }
   }
 
+  protected lazy val optionalFieldIndices: Seq[Int] = {
+    optionalGenotypeFields.map { f =>
+      gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _))
+    }
+  }
+
   protected lazy val genotypeStructSize: Int = {
-    val gStruct = genotypesExpr
-      .dataType
-      .asInstanceOf[ArrayType]
-      .elementType
-      .asInstanceOf[StructType]
     gStruct.length
   }
diff --git a/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala b/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala
index a11159a11..cc14e3655 100644
--- a/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala
+++ b/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala
@@ -178,8 +178,8 @@ private[projectglow] object VariantNormalizer extends GlowLogging {
    * normalizes a single VariantContext by checking some conditions and then calling realignAlleles
    *
    * @param vc
-   * @param refGenomePathString
-   * @return: normalized VariantContext
+   * @param refGenomeDataSource data source for the reference genome
+   * @return normalized VariantContext
    */
   private def normalizeVC(
       vc: VariantContext,
diff --git a/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala b/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala
index 5d34ceb4f..396ccecf7 100644
--- a/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala
+++ b/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala
@@ -16,39 +16,26 @@
 package io.projectglow.tertiary
 
-import org.apache.spark.sql.{DataFrame, Row}
-
-import io.projectglow.common.VCFRow
-import io.projectglow.common.VCFRow
-import io.projectglow.sql.GlowBaseTest
+import io.projectglow.common.{VCFRow, VariantSchemas}
 import io.projectglow.sql.GlowBaseTest
+import org.apache.spark.sql.{DataFrame, Row}
 
 class SampleQcExprsSuite extends GlowBaseTest {
   lazy val testVcf = s"$testDataHome/1000G.phase3.broad.withGenotypes.chr20.10100000.vcf"
   lazy val na12878 = s"$testDataHome/CEUTrio.HiSeq.WGS.b37.NA12878.20.21.vcf"
 
-  lazy private val sess = {
-    // Set small partition size so that `merge` code path is implemented
-    spark.conf.set("spark.sql.files.maxPartitionBytes", 512)
-    spark
-  }
+  lazy private val sess = spark
 
-  test("high level test") {
+  gridTest("sample_call_summary_stats high level test")(Seq(true, false)) { sampleIds =>
     import sess.implicits._
-    val df = spark
-      .read
-      .format("vcf")
-      .option("includeSampleIds", true)
-      .load(na12878)
+    val df = readVcf(na12878, sampleIds)
     val stats = df
       .selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc")
-      .selectExpr("explode(qc) as (sampleId, sqc)")
-      .selectExpr("sampleId", "expand_struct(sqc)")
+      .selectExpr("expand_struct(qc[0])")
       .as[TestSampleCallStats]
       .head
 
     // Golden value is from Hail 0.2.12
     val expected = TestSampleCallStats(
-      "NA12878",
       1,
       1075,
       0,
@@ -87,19 +74,20 @@
     spark.createDataFrame(data)
   }
 
-  private def readVcf(path: String): DataFrame = {
+  private def readVcf(path: String, includeSampleIds: Boolean = true): DataFrame = {
     spark
       .read
       .format("vcf")
-      .option("includeSampleIds", true)
+      .option("includeSampleIds", includeSampleIds)
       .load(path)
+      .repartition(4)
   }
 
   private def toCallStats(df: DataFrame): Seq[TestSampleCallStats] = {
     import sess.implicits._
     df.selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc")
-      .selectExpr("explode(qc) as (sampleId, sqc)")
-      .selectExpr("sampleId", "expand_struct(sqc)")
+      .selectExpr("explode(qc) as sqc")
+      .selectExpr("expand_struct(sqc)")
       .as[TestSampleCallStats]
       .collect()
   }
@@ -117,12 +105,12 @@
     assert(stats.isEmpty) // No error expected
   }
 
-  test("dp stats") {
+  gridTest("dp stats")(Seq(true, false)) { sampleIds =>
     import sess.implicits._
-    val stats = readVcf(na12878)
+    val stats = readVcf(na12878, includeSampleIds = sampleIds)
       .selectExpr("sample_dp_summary_stats(genotypes) as stats")
-      .selectExpr("explode(stats) as (sampleId, stats)")
-      .selectExpr("sampleId", "expand_struct(stats)")
+      .selectExpr("explode(stats) as dp_stats")
+      .selectExpr("expand_struct(dp_stats)")
       .as[ArraySummaryStats]
       .head
     assert(stats.min.get ~== 1 relTol 0.2)
@@ -135,7 +123,7 @@
     import sess.implicits._
     val stats = makeDf(Seq(Some(1), None, Some(3)))
       .selectExpr("sample_dp_summary_stats(genotypes) as stats")
-      .selectExpr("explode(stats) as (sampleId, stats)")
+      .selectExpr("explode(stats) as stats")
       .selectExpr("expand_struct(stats)")
       .as[ArraySummaryStats]
       .head()
@@ -148,8 +136,8 @@
     import sess.implicits._
     val stats = readVcf(na12878)
       .selectExpr("sample_gq_summary_stats(genotypes) as stats")
-      .selectExpr("explode(stats) as (sampleId, stats)")
-      .selectExpr("sampleId", "expand_struct(stats)")
+      .selectExpr("explode(stats) as stats")
+      .selectExpr("expand_struct(stats)")
       .as[ArraySummaryStats]
       .head
     assert(stats.min.get ~== 3 relTol 0.2)
@@ -157,10 +145,28 @@
     assert(stats.mean.get ~== 89.2 relTol 0.2)
     assert(stats.stdDev.get ~== 23.2 relTol 0.2)
   }
+
+  private val expressionsToTest = Seq(
+    "expand_struct(sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles)[0])",
+    "expand_struct(sample_gq_summary_stats(genotypes)[0])",
+    "expand_struct(sample_dp_summary_stats(genotypes)[0])"
+  )
+  private val testCases = expressionsToTest
+    .flatMap(expr => Seq((expr, true), (expr, false)))
+  gridTest("sample ids are propagated if included")(testCases) {
+    case (expr, sampleIds) =>
+      import sess.implicits._
+      val df = readVcf(na12878, sampleIds)
+        .selectExpr(expr)
+      val outputSchema = df.schema
+      assert(outputSchema.exists(_.name == VariantSchemas.sampleIdField.name) == sampleIds)
+      if (sampleIds) {
+        assert(df.select("sampleId").as[String].head == "NA12878")
+      }
+  }
 }
 
 case class TestSampleCallStats(
-    sampleId: String,
     callRate: Double,
     nCalled: Long,
    nUncalled: Long,
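
A minimal usage sketch of the array-based result, mirroring the test patterns above (the input path is hypothetical, and a Spark session with Glow registered is assumed):

// Hypothetical input path; any VCF with genotype calls works here.
val df = spark
  .read
  .format("vcf")
  .option("includeSampleIds", true)
  .load("/path/to/file.vcf")

// The aggregation now yields array<struct<...>> rather than map<string, struct<...>>;
// when sample ids were read, each struct carries a leading sampleId field that
// survives expand_struct.
df.selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc")
  .selectExpr("explode(qc) as sqc")
  .selectExpr("expand_struct(sqc)")
  .show()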