From 1ff61df0d1eb28cc062403df12a891403f10ee06 Mon Sep 17 00:00:00 2001 From: Henry D Date: Tue, 15 Oct 2019 10:52:46 -0700 Subject: [PATCH 1/3] some work Make sample qc return an array instead of a map Signed-off-by: Henry D --- .../sql/expressions/MomentAggState.scala | 17 ++-- .../PerSampleSummaryStatistics.scala | 54 +++++++---- .../expressions/SampleCallSummaryStats.scala | 95 ++++++++++++------- .../sql/util/ExpectsGenotypeFields.scala | 24 +++-- .../tertiary/SampleQcExprsSuite.scala | 47 +++++---- 5 files changed, 150 insertions(+), 87 deletions(-) diff --git a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala index dc169f3ec..5bae76c01 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala @@ -60,15 +60,16 @@ case class MomentAggState( def update(element: Int): Unit = update(element.toDouble) def update(element: Float): Unit = update(element.toDouble) + def toInternalRow(row: InternalRow): InternalRow = { + row.update(0, if (count > 0) mean else null) + row.update(1, if (count > 0) Math.sqrt(m2 / (count - 1))) + row.update(2, if (count > 0) min else null) + row.update(3, if (count > 0) max else null) + row + } + def toInternalRow: InternalRow = { - new GenericInternalRow( - Array( - if (count > 0) mean else null, - if (count > 0) Math.sqrt(m2 / (count - 1)) else null, - if (count > 0) min else null, - if (count > 0) max else null - ) - ) + toInternalRow(new GenericInternalRow(4)) } } diff --git a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala index 9bcfdda63..ea854caf7 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala @@ -23,12 +23,11 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - import io.projectglow.common.{GlowLogging, VariantSchemas} import io.projectglow.sql.util.ExpectsGenotypeFields @@ -41,7 +40,8 @@ case class SampleSummaryStatsState(var sampleId: String, var momentAggState: Mom * sample in a cohort. The field is determined by the provided [[StructField]]. If the field does * not exist in the genotype struct, an analysis error will be thrown. * - * The return type is a map of sampleId -> summary statistics. + * The return type is an array of summary statistics. If sample ids are included in the input, + * they'll be propagated to the results. 
*/ case class PerSampleSummaryStatistics( genotypes: Expression, @@ -54,38 +54,52 @@ case class PerSampleSummaryStatistics( override def children: Seq[Expression] = Seq(genotypes) override def nullable: Boolean = false - override def dataType: DataType = MapType(StringType, MomentAggState.schema) + + override def dataType: DataType = + if (optionalFieldIndices(0) != -1) { + ArrayType(MomentAggState.schema.add(VariantSchemas.sampleIdField)) + } else { + ArrayType(MomentAggState.schema) + } override def genotypesExpr: Expression = genotypes - override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.sampleIdField, field) + override def genotypeFieldsRequired: Seq[StructField] = Seq(field) + override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField) override def createAggregationBuffer(): ArrayBuffer[SampleSummaryStatsState] = { mutable.ArrayBuffer[SampleSummaryStatsState]() } override def eval(buffer: ArrayBuffer[SampleSummaryStatsState]): Any = { - val keys = new GenericArrayData(buffer.map(s => UTF8String.fromString(s.sampleId))) - val values = new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow)) - new ArrayBasedMapData(keys, values) + if (optionalFieldIndices(0) == -1) { // no sample ids + new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow)) + } else { + new GenericArrayData(buffer.map { s => + val outputRow = new GenericInternalRow(MomentAggState.schema.length + 1) + s.momentAggState.toInternalRow(outputRow) + outputRow.update(MomentAggState.schema.length, UTF8String.fromString(s.sampleId)) + outputRow + }) + } } private lazy val updateStateFn: (MomentAggState, InternalRow) => Unit = { field.dataType match { case FloatType => (state, genotype) => { - state.update(genotype.getFloat(genotypeFieldIndices(1))) + state.update(genotype.getFloat(genotypeFieldIndices(0))) } case DoubleType => (state, genotype) => { - state.update(genotype.getDouble(genotypeFieldIndices(1))) + state.update(genotype.getDouble(genotypeFieldIndices(0))) } case IntegerType => (state, genotype) => { - state.update(genotype.getInt(genotypeFieldIndices(1))) + state.update(genotype.getInt(genotypeFieldIndices(0))) } case LongType => (state, genotype) => { - state.update(genotype.getLong(genotypeFieldIndices(1))) + state.update(genotype.getLong(genotypeFieldIndices(0))) } } } @@ -100,14 +114,18 @@ case class PerSampleSummaryStatistics( // Make sure the buffer has an entry for this sample if (i >= buffer.size) { - val sampleId = genotypesArray - .getStruct(buffer.size, genotypeStructSize) - .getString(genotypeFieldIndices.head) + val sampleId = if (optionalFieldIndices(0) != -1) { + genotypesArray + .getStruct(buffer.size, genotypeStructSize) + .getString(optionalFieldIndices(0)) + } else { + null + } buffer.append(SampleSummaryStatsState(sampleId, MomentAggState())) } val struct = genotypesArray.getStruct(i, genotypeStructSize) - if (!struct.isNullAt(genotypeFieldIndices(1))) { + if (!struct.isNullAt(genotypeFieldIndices(0))) { updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, genotypeStructSize)) } i += 1 @@ -130,7 +148,9 @@ case class PerSampleSummaryStatistics( ) var i = 0 while (i < buffer.size) { - require(buffer(i).sampleId == input(i).sampleId, s"Samples did not match at position $i") + require( + buffer(i).sampleId == input(i).sampleId, + s"Samples did not match at position $i (${buffer(i).sampleId}, ${input(i).sampleId})") buffer(i).momentAggState = MomentAggState.merge(buffer(i).momentAggState, 
input(i).momentAggState) i += 1 diff --git a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala index 457b81fbb..e08f5e3f3 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala @@ -21,23 +21,23 @@ import java.nio.ByteBuffer import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import io.projectglow.common.{GlowLogging, VariantSchemas} +import io.projectglow.sql.util.ExpectsGenotypeFields import org.apache.spark.SparkEnv import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import io.projectglow.common.{GlowLogging, VariantSchemas} -import io.projectglow.sql.util.ExpectsGenotypeFields - /** * Computes summary statistics per-sample in a genomic cohort. These statistics include the call * rate and the number of different types of variants. * - * The return type is a map of sampleId -> summary statistics. + * The return type is an array of summary statistics. If sample ids are included in the input + * schema, they'll be propagated to the results. */ case class CallSummaryStats( genotypes: Expression, @@ -52,12 +52,15 @@ case class CallSummaryStats( override def genotypesExpr: Expression = genotypes override def genotypeFieldsRequired: Seq[StructField] = { - Seq(VariantSchemas.callsField, VariantSchemas.sampleIdField) + Seq(VariantSchemas.callsField) } + + override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField) override def children: Seq[Expression] = Seq(genotypes, refAllele, altAlleles) override def nullable: Boolean = false - override def dataType: DataType = MapType(StringType, SampleCallStats.outputSchema) + override def dataType: DataType = + ArrayType(SampleCallStats.outputSchema(optionalFieldIndices(0) != -1)) override def checkInputDataTypes(): TypeCheckResult = { if (super.checkInputDataTypes().isFailure) { @@ -86,9 +89,7 @@ case class CallSummaryStats( } override def eval(buffer: ArrayBuffer[SampleCallStats]): Any = { - val keys = new GenericArrayData(buffer.map(s => UTF8String.fromString(s.sampleId))) - val values = new GenericArrayData(buffer.map(_.toInternalRow)) - new ArrayBasedMapData(keys, values) + new GenericArrayData(buffer.map(_.toInternalRow(optionalFieldIndices(0) != -1))) } override def update( @@ -114,8 +115,12 @@ case class CallSummaryStats( // Make sure the buffer has an entry for this sample if (j >= buffer.size) { - val sampleId = struct - .getUTF8String(genotypeFieldIndices(1)) + val sampleId = if (optionalFieldIndices(0) != -1) { + struct + .getUTF8String(optionalFieldIndices(0)) + } else { + null + } buffer.append(SampleCallStats(sampleId.toString)) } @@ -212,7 +217,7 @@ case class CallSummaryStats( } case class SampleCallStats( - var sampleId: String, + var sampleId: String = null, var nCalled: Long = 0, var nUncalled: Long = 0, var nHomRef: Long = 0, @@ -224,29 +229,46 @@ case class SampleCallStats( var nTransition: 
Long = 0, var nSpanningDeletion: Long = 0) { - def this() = { - this(null) - } - - def toInternalRow: InternalRow = + def toInternalRow(includeSampleId: Boolean): InternalRow = new GenericInternalRow( - Array[Any]( - nCalled.toDouble / (nCalled + nUncalled), - nCalled, - nUncalled, - nHomRef, - nHet, - nHomVar, - nTransition + nTransversion, - nInsertion, - nDeletion, - nTransition, - nTransversion, - nSpanningDeletion, - nTransition.toDouble / nTransversion, - nInsertion.toDouble / nDeletion, - nHet.toDouble / nHomVar - ) + if (includeSampleId) { + Array[Any]( + UTF8String.fromString(sampleId), + nCalled.toDouble / (nCalled + nUncalled), + nCalled, + nUncalled, + nHomRef, + nHet, + nHomVar, + nTransition + nTransversion, + nInsertion, + nDeletion, + nTransition, + nTransversion, + nSpanningDeletion, + nTransition.toDouble / nTransversion, + nInsertion.toDouble / nDeletion, + nHet.toDouble / nHomVar + ) + } else { + Array[Any]( + nCalled.toDouble / (nCalled + nUncalled), + nCalled, + nUncalled, + nHomRef, + nHet, + nHomVar, + nTransition + nTransversion, + nInsertion, + nDeletion, + nTransition, + nTransversion, + nSpanningDeletion, + nTransition.toDouble / nTransversion, + nInsertion.toDouble / nDeletion, + nHet.toDouble / nHomVar + ) + } ) } @@ -267,7 +289,8 @@ object SampleCallStats extends GlowLogging { out } - val outputSchema = StructType( + private[projectglow] def outputSchema(includeSampleId: Boolean): StructType = StructType( + (if (includeSampleId) Some(VariantSchemas.sampleIdField) else None).toSeq ++ Seq( StructField("callRate", DoubleType), StructField("nCalled", LongType), diff --git a/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala b/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala index 5db85c49a..f78397aba 100644 --- a/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala +++ b/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala @@ -30,23 +30,27 @@ trait ExpectsGenotypeFields extends Expression { protected def genotypeFieldsRequired: Seq[StructField] + protected def optionalGenotypeFields: Seq[StructField] = Seq.empty + + private lazy val gStruct = genotypesExpr + .dataType + .asInstanceOf[ArrayType] + .elementType + .asInstanceOf[StructType] + protected lazy val genotypeFieldIndices: Seq[Int] = { - val gStruct = genotypesExpr - .dataType - .asInstanceOf[ArrayType] - .elementType - .asInstanceOf[StructType] genotypeFieldsRequired.map { f => gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _)) } } + protected lazy val optionalFieldIndices: Seq[Int] = { + optionalGenotypeFields.map { f => + gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _)) + } + } + protected lazy val genotypeStructSize: Int = { - val gStruct = genotypesExpr - .dataType - .asInstanceOf[ArrayType] - .elementType - .asInstanceOf[StructType] gStruct.length } diff --git a/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala b/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala index 5d34ceb4f..19c9d2e91 100644 --- a/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala +++ b/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala @@ -17,9 +17,7 @@ package io.projectglow.tertiary import org.apache.spark.sql.{DataFrame, Row} - -import io.projectglow.common.VCFRow -import io.projectglow.common.VCFRow +import io.projectglow.common.{VCFRow, VariantSchemas} import io.projectglow.sql.GlowBaseTest import io.projectglow.sql.GlowBaseTest @@ 
-41,14 +39,13 @@ class SampleQcExprsSuite extends GlowBaseTest { .load(na12878) val stats = df .selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc") - .selectExpr("explode(qc) as (sampleId, sqc)") - .selectExpr("sampleId", "expand_struct(sqc)") + .selectExpr("explode(qc) as sqc") + .selectExpr("expand_struct(sqc)") .as[TestSampleCallStats] .head // Golden value is from Hail 0.2.12 val expected = TestSampleCallStats( - "NA12878", 1, 1075, 0, @@ -87,19 +84,19 @@ class SampleQcExprsSuite extends GlowBaseTest { spark.createDataFrame(data) } - private def readVcf(path: String): DataFrame = { + private def readVcf(path: String, includeSampleIds: Boolean = true): DataFrame = { spark .read .format("vcf") - .option("includeSampleIds", true) + .option("includeSampleIds", includeSampleIds) .load(path) } private def toCallStats(df: DataFrame): Seq[TestSampleCallStats] = { import sess.implicits._ df.selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc") - .selectExpr("explode(qc) as (sampleId, sqc)") - .selectExpr("sampleId", "expand_struct(sqc)") + .selectExpr("explode(qc) as sqc") + .selectExpr("expand_struct(sqc)") .as[TestSampleCallStats] .collect() } @@ -121,8 +118,8 @@ class SampleQcExprsSuite extends GlowBaseTest { import sess.implicits._ val stats = readVcf(na12878) .selectExpr("sample_dp_summary_stats(genotypes) as stats") - .selectExpr("explode(stats) as (sampleId, stats)") - .selectExpr("sampleId", "expand_struct(stats)") + .selectExpr("explode(stats) as dp_stats") + .selectExpr("expand_struct(dp_stats)") .as[ArraySummaryStats] .head assert(stats.min.get ~== 1 relTol 0.2) @@ -135,7 +132,7 @@ class SampleQcExprsSuite extends GlowBaseTest { import sess.implicits._ val stats = makeDf(Seq(Some(1), None, Some(3))) .selectExpr("sample_dp_summary_stats(genotypes) as stats") - .selectExpr("explode(stats) as (sampleId, stats)") + .selectExpr("explode(stats) as stats") .selectExpr("expand_struct(stats)") .as[ArraySummaryStats] .head() @@ -148,8 +145,8 @@ class SampleQcExprsSuite extends GlowBaseTest { import sess.implicits._ val stats = readVcf(na12878) .selectExpr("sample_gq_summary_stats(genotypes) as stats") - .selectExpr("explode(stats) as (sampleId, stats)") - .selectExpr("sampleId", "expand_struct(stats)") + .selectExpr("explode(stats) as stats") + .selectExpr("expand_struct(stats)") .as[ArraySummaryStats] .head assert(stats.min.get ~== 3 relTol 0.2) @@ -157,10 +154,28 @@ class SampleQcExprsSuite extends GlowBaseTest { assert(stats.mean.get ~== 89.2 relTol 0.2) assert(stats.stdDev.get ~== 23.2 relTol 0.2) } + + private val expressionsToTest = Seq( + "expand_struct(sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles)[0])", + "expand_struct(sample_gq_summary_stats(genotypes)[0])", + "expand_struct(sample_dp_summary_stats(genotypes)[0])" + ) + private val testCases = expressionsToTest + .flatMap(expr => Seq((expr, true), (expr, false))) + gridTest("sample ids are propagated if included")(testCases) { + case (expr, sampleIds) => + import sess.implicits._ + val df = readVcf(na12878, sampleIds) + .selectExpr(expr) + val outputSchema = df.schema + assert(outputSchema.exists(_.name == VariantSchemas.sampleIdField.name) == sampleIds) + if (sampleIds) { + assert(df.select("sampleId").as[String].head == "NA12878") + } + } } case class TestSampleCallStats( - sampleId: String, callRate: Double, nCalled: Long, nUncalled: Long, From c70bb46b83d55ca150c89288574c0a99d2b8251b Mon Sep 17 00:00:00 2001 From: Henry D Date: 
Tue, 15 Oct 2019 14:17:41 -0700 Subject: [PATCH 2/3] fix test Signed-off-by: Henry D --- .../scala/io/projectglow/sql/expressions/MomentAggState.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala index 5bae76c01..76172998a 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala @@ -62,7 +62,7 @@ case class MomentAggState( def toInternalRow(row: InternalRow): InternalRow = { row.update(0, if (count > 0) mean else null) - row.update(1, if (count > 0) Math.sqrt(m2 / (count - 1))) + row.update(1, if (count > 0) Math.sqrt(m2 / (count - 1)) else null) row.update(2, if (count > 0) min else null) row.update(3, if (count > 0) max else null) row From 779bd74f528798df7a63f88c04b931156d90a67a Mon Sep 17 00:00:00 2001 From: Henry D Date: Wed, 16 Oct 2019 10:31:22 -0700 Subject: [PATCH 3/3] karen's comments Signed-off-by: Henry D --- .../sql/expressions/MomentAggState.scala | 13 +-- .../PerSampleSummaryStatistics.scala | 15 ++-- .../expressions/SampleCallSummaryStats.scala | 79 ++++++++----------- .../normalizevariants/VariantNormalizer.scala | 4 +- .../tertiary/SampleQcExprsSuite.scala | 25 ++---- 5 files changed, 58 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala index 76172998a..04978a273 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/MomentAggState.scala @@ -60,11 +60,14 @@ case class MomentAggState( def update(element: Int): Unit = update(element.toDouble) def update(element: Float): Unit = update(element.toDouble) - def toInternalRow(row: InternalRow): InternalRow = { - row.update(0, if (count > 0) mean else null) - row.update(1, if (count > 0) Math.sqrt(m2 / (count - 1)) else null) - row.update(2, if (count > 0) min else null) - row.update(3, if (count > 0) max else null) + /** + * Writes the mean, stdev, min, and max into the input row beginning at the provided offset. 
+ */ + def toInternalRow(row: InternalRow, offset: Int = 0): InternalRow = { + row.update(offset, if (count > 0) mean else null) + row.update(offset + 1, if (count > 0) Math.sqrt(m2 / (count - 1)) else null) + row.update(offset + 2, if (count > 0) min else null) + row.update(offset + 3, if (count > 0) max else null) row } diff --git a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala index ea854caf7..56d134873 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala @@ -56,8 +56,9 @@ case class PerSampleSummaryStatistics( override def nullable: Boolean = false override def dataType: DataType = - if (optionalFieldIndices(0) != -1) { - ArrayType(MomentAggState.schema.add(VariantSchemas.sampleIdField)) + if (hasSampleIds) { + val fields = VariantSchemas.sampleIdField +: MomentAggState.schema.fields + ArrayType(StructType(fields)) } else { ArrayType(MomentAggState.schema) } @@ -65,20 +66,20 @@ case class PerSampleSummaryStatistics( override def genotypesExpr: Expression = genotypes override def genotypeFieldsRequired: Seq[StructField] = Seq(field) override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField) + private lazy val hasSampleIds = optionalFieldIndices(0) != -1 override def createAggregationBuffer(): ArrayBuffer[SampleSummaryStatsState] = { mutable.ArrayBuffer[SampleSummaryStatsState]() } override def eval(buffer: ArrayBuffer[SampleSummaryStatsState]): Any = { - if (optionalFieldIndices(0) == -1) { // no sample ids + if (!hasSampleIds) { new GenericArrayData(buffer.map(s => s.momentAggState.toInternalRow)) } else { new GenericArrayData(buffer.map { s => val outputRow = new GenericInternalRow(MomentAggState.schema.length + 1) - s.momentAggState.toInternalRow(outputRow) - outputRow.update(MomentAggState.schema.length, UTF8String.fromString(s.sampleId)) - outputRow + outputRow.update(0, UTF8String.fromString(s.sampleId)) + s.momentAggState.toInternalRow(outputRow, offset = 1) }) } } @@ -114,7 +115,7 @@ case class PerSampleSummaryStatistics( // Make sure the buffer has an entry for this sample if (i >= buffer.size) { - val sampleId = if (optionalFieldIndices(0) != -1) { + val sampleId = if (hasSampleIds) { genotypesArray .getStruct(buffer.size, genotypeStructSize) .getString(optionalFieldIndices(0)) diff --git a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala index e08f5e3f3..c52a83113 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala @@ -56,11 +56,12 @@ case class CallSummaryStats( } override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField) + private lazy val hasSampleIds = optionalFieldIndices(0) != -1 override def children: Seq[Expression] = Seq(genotypes, refAllele, altAlleles) override def nullable: Boolean = false override def dataType: DataType = - ArrayType(SampleCallStats.outputSchema(optionalFieldIndices(0) != -1)) + ArrayType(SampleCallStats.outputSchema(hasSampleIds)) override def checkInputDataTypes(): TypeCheckResult = { if (super.checkInputDataTypes().isFailure) { @@ -89,7 +90,7 @@ case class CallSummaryStats( } override def 
eval(buffer: ArrayBuffer[SampleCallStats]): Any = { - new GenericArrayData(buffer.map(_.toInternalRow(optionalFieldIndices(0) != -1))) + new GenericArrayData(buffer.map(_.toInternalRow(hasSampleIds))) } override def update( @@ -115,13 +116,12 @@ case class CallSummaryStats( // Make sure the buffer has an entry for this sample if (j >= buffer.size) { - val sampleId = if (optionalFieldIndices(0) != -1) { - struct - .getUTF8String(optionalFieldIndices(0)) + val sampleId = if (hasSampleIds) { + struct.getUTF8String(optionalFieldIndices(0)) } else { null } - buffer.append(SampleCallStats(sampleId.toString)) + buffer.append(SampleCallStats(if (sampleId != null) sampleId.toString else null)) } val stats = buffer(j) @@ -229,47 +229,32 @@ case class SampleCallStats( var nTransition: Long = 0, var nSpanningDeletion: Long = 0) { - def toInternalRow(includeSampleId: Boolean): InternalRow = - new GenericInternalRow( - if (includeSampleId) { - Array[Any]( - UTF8String.fromString(sampleId), - nCalled.toDouble / (nCalled + nUncalled), - nCalled, - nUncalled, - nHomRef, - nHet, - nHomVar, - nTransition + nTransversion, - nInsertion, - nDeletion, - nTransition, - nTransversion, - nSpanningDeletion, - nTransition.toDouble / nTransversion, - nInsertion.toDouble / nDeletion, - nHet.toDouble / nHomVar - ) - } else { - Array[Any]( - nCalled.toDouble / (nCalled + nUncalled), - nCalled, - nUncalled, - nHomRef, - nHet, - nHomVar, - nTransition + nTransversion, - nInsertion, - nDeletion, - nTransition, - nTransversion, - nSpanningDeletion, - nTransition.toDouble / nTransversion, - nInsertion.toDouble / nDeletion, - nHet.toDouble / nHomVar - ) - } - ) + def toInternalRow(includeSampleId: Boolean): InternalRow = { + val valueArr = + Array[Any]( + nCalled.toDouble / (nCalled + nUncalled), + nCalled, + nUncalled, + nHomRef, + nHet, + nHomVar, + nTransition + nTransversion, + nInsertion, + nDeletion, + nTransition, + nTransversion, + nSpanningDeletion, + nTransition.toDouble / nTransversion, + nInsertion.toDouble / nDeletion, + nHet.toDouble / nHomVar + ) + + if (includeSampleId) { + new GenericInternalRow(Array[Any](UTF8String.fromString(sampleId)) ++ valueArr) + } else { + new GenericInternalRow(valueArr) + } + } } object SampleCallStats extends GlowLogging { diff --git a/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala b/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala index a11159a11..cc14e3655 100644 --- a/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala +++ b/core/src/main/scala/io/projectglow/transformers/normalizevariants/VariantNormalizer.scala @@ -178,8 +178,8 @@ private[projectglow] object VariantNormalizer extends GlowLogging { * normalizes a single VariantContext by checking some conditions and then calling realignAlleles * * @param vc - * @param refGenomePathString - * @return: normalized VariantContext + * @param refGenomeDataSource + * @return normalized VariantContext */ private def normalizeVC( vc: VariantContext, diff --git a/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala b/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala index 19c9d2e91..396ccecf7 100644 --- a/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala +++ b/core/src/test/scala/io/projectglow/tertiary/SampleQcExprsSuite.scala @@ -16,31 +16,21 @@ package io.projectglow.tertiary -import org.apache.spark.sql.{DataFrame, Row} import io.projectglow.common.{VCFRow, 
VariantSchemas} import io.projectglow.sql.GlowBaseTest -import io.projectglow.sql.GlowBaseTest +import org.apache.spark.sql.{DataFrame, Row} class SampleQcExprsSuite extends GlowBaseTest { lazy val testVcf = s"$testDataHome/1000G.phase3.broad.withGenotypes.chr20.10100000.vcf" lazy val na12878 = s"$testDataHome/CEUTrio.HiSeq.WGS.b37.NA12878.20.21.vcf" - lazy private val sess = { - // Set small partition size so that `merge` code path is implemented - spark.conf.set("spark.sql.files.maxPartitionBytes", 512) - spark - } + lazy private val sess = spark - test("high level test") { + gridTest("sample_call_summary_stats high level test")(Seq(true, false)) { sampleIds => import sess.implicits._ - val df = spark - .read - .format("vcf") - .option("includeSampleIds", true) - .load(na12878) + val df = readVcf(na12878, sampleIds) val stats = df .selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc") - .selectExpr("explode(qc) as sqc") - .selectExpr("expand_struct(sqc)") + .selectExpr("expand_struct(qc[0])") .as[TestSampleCallStats] .head @@ -90,6 +80,7 @@ class SampleQcExprsSuite extends GlowBaseTest { .format("vcf") .option("includeSampleIds", includeSampleIds) .load(path) + .repartition(4) } private def toCallStats(df: DataFrame): Seq[TestSampleCallStats] = { @@ -114,9 +105,9 @@ class SampleQcExprsSuite extends GlowBaseTest { assert(stats.isEmpty) // No error expected } - test("dp stats") { + gridTest("dp stats")(Seq(true, false)) { sampleIds => import sess.implicits._ - val stats = readVcf(na12878) + val stats = readVcf(na12878, includeSampleIds = sampleIds) .selectExpr("sample_dp_summary_stats(genotypes) as stats") .selectExpr("explode(stats) as dp_stats") .selectExpr("expand_struct(dp_stats)")
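
Taken together, the three patches replace the map-of-sampleId return type with an array of structs that optionally carry the sample id as a leading field. The enabling piece is the required/optional split added to ExpectsGenotypeFields: required genotype fields must resolve to an index in the genotype struct, while optional ones may resolve to -1, which callers read as "absent". A minimal, self-contained sketch of that lookup, using plain name/type equality where the real trait delegates to SQLUtils.structFieldsEqualExceptNullability (the object and method names here are illustrative, not part of the patch):

    import org.apache.spark.sql.types._

    object FieldResolutionSketch {
      // Resolve each requested field to its index in the genotype struct.
      // Optional fields that are missing come back as -1 rather than failing.
      def resolve(
          gStruct: StructType,
          required: Seq[StructField],
          optional: Seq[StructField]): (Seq[Int], Seq[Int]) = {
        def indexOf(f: StructField): Int =
          gStruct.indexWhere(g => g.name == f.name && g.dataType == f.dataType)
        (required.map(indexOf), optional.map(indexOf))
      }

      def main(args: Array[String]): Unit = {
        val genotype = StructType(Seq(StructField("calls", ArrayType(IntegerType))))
        val (req, opt) = resolve(
          genotype,
          required = Seq(StructField("calls", ArrayType(IntegerType))),
          optional = Seq(StructField("sampleId", StringType)))
        println(req) // List(0)
        println(opt) // List(-1): no sample ids, so the output omits them
      }
    }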
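With that -1 sentinel in hand, each aggregator's dataType simply prepends the sample id field when ids are present. A sketch of the shape PATCH 3/3 lands on, where resultType is a hypothetical free-standing version of PerSampleSummaryStatistics.dataType:

    import org.apache.spark.sql.types._

    // Prepend the sample id field only when the input genotypes carry ids.
    def resultType(hasSampleIds: Boolean, statsSchema: StructType): DataType =
      if (hasSampleIds) {
        ArrayType(StructType(StructField("sampleId", StringType) +: statsSchema.fields))
      } else {
        ArrayType(statsSchema)
      }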
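Putting the sample id in slot 0 is what motivates the offset parameter on MomentAggState.toInternalRow in PATCH 3/3, and every stats slot needs the `else null` guard that PATCH 2/3 restores. A compilable sketch of both patterns, with StatsSketch standing in for MomentAggState and keeping only the mean and sample standard deviation:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
    import org.apache.spark.unsafe.types.UTF8String

    case class StatsSketch(var count: Long = 0, var mean: Double = 0, var m2: Double = 0) {
      // Welford-style online update, as in MomentAggState.update.
      def update(x: Double): Unit = {
        count += 1
        val delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
      }

      // Write stats starting at `offset`, leaving earlier slots (the sample
      // id) untouched. The missing `else null` on the stdev slot is exactly
      // what PATCH 2/3 repairs.
      def writeInto(row: InternalRow, offset: Int): InternalRow = {
        row.update(offset, if (count > 0) mean else null)
        row.update(offset + 1, if (count > 0) math.sqrt(m2 / (count - 1)) else null)
        row
      }
    }

    object StatsSketch {
      def main(args: Array[String]): Unit = {
        val s = StatsSketch()
        Seq(1.0, 2.0, 3.0).foreach(s.update)
        val row = new GenericInternalRow(3)
        row.update(0, UTF8String.fromString("NA12878")) // sample id in slot 0
        s.writeInto(row, offset = 1)
        println((row.getUTF8String(0), row.getDouble(1), row.getDouble(2))) // (NA12878,2.0,1.0)
      }
    }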
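On the query side, the map-to-array change shows up in the test suite as a different explode shape. Assuming df is a VCF DataFrame with genotypes, referenceAllele, and alternateAlleles columns (as in the tests), the before/after is:

    // Before: the map return type exploded into key/value pairs.
    df.selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc")
      .selectExpr("explode(qc) as (sampleId, sqc)")
      .selectExpr("sampleId", "expand_struct(sqc)")

    // After: explode yields the stats struct directly; the sample id rides
    // along inside it when the input was read with includeSampleIds=true.
    df.selectExpr("sample_call_summary_stats(genotypes, referenceAllele, alternateAlleles) as qc")
      .selectExpr("explode(qc) as sqc")
      .selectExpr("expand_struct(sqc)")

    // Single-sample shortcut used by the new grid tests:
    df.selectExpr("expand_struct(sample_dp_summary_stats(genotypes)[0])")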