diff --git a/build.sbt b/build.sbt index 2c4a851c2..8ebef4d5e 100644 --- a/build.sbt +++ b/build.sbt @@ -48,6 +48,7 @@ def groupByHash(tests: Seq[TestDefinition]): Seq[Tests.Group] = { case (i, groupTests) => val options = ForkOptions() .withRunJVMOptions(Vector("-Dspark.ui.enabled=false", "-Xmx1024m")) + Group(i.toString, groupTests, SubProcess(options)) } .toSeq diff --git a/core/src/main/scala/io/projectglow/sql/SqlExtensionProvider.scala b/core/src/main/scala/io/projectglow/sql/SqlExtensionProvider.scala index c24860606..ab28f6b44 100644 --- a/core/src/main/scala/io/projectglow/sql/SqlExtensionProvider.scala +++ b/core/src/main/scala/io/projectglow/sql/SqlExtensionProvider.scala @@ -32,13 +32,17 @@ import org.yaml.snakeyaml.Yaml import io.projectglow.SparkShim._ import io.projectglow.common.WithUtils import io.projectglow.sql.expressions._ -import io.projectglow.sql.optimizer.{ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule} +import io.projectglow.sql.optimizer.{ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule, ResolveGenotypeFields} // TODO(hhd): Spark 3.0 allows extensions to register functions. After Spark 3.0 is released, // we should move all extensions into this class. 
class GlowSQLExtensions extends (SparkSessionExtensions => Unit) { val resolutionRules: Seq[Rule[LogicalPlan]] = - Seq(ReplaceExpressionsRule, ResolveAggregateFunctionsRule, ResolveExpandStructRule) + Seq( + ReplaceExpressionsRule, + ResolveAggregateFunctionsRule, + ResolveExpandStructRule, + ResolveGenotypeFields) val optimizations: Seq[Rule[LogicalPlan]] = Seq() def apply(extensions: SparkSessionExtensions): Unit = { diff --git a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala index 54f45e95c..887a55dd1 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/PerSampleSummaryStatistics.scala @@ -22,16 +22,16 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkEnv -import org.apache.spark.sql.{AnalysisException, SQLUtils} +import org.apache.spark.sql.SQLUtils import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, Literal, UnaryExpression, Unevaluable} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import io.projectglow.common.{GlowLogging, VariantSchemas} -import io.projectglow.sql.util.{ExpectsGenotypeFields, Rewrite} +import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo, Rewrite} case class SampleSummaryStatsState(var 
sampleId: String, var momentAggState: MomentAggState) { def this() = this(null, null) // need 0-arg constructor for serialization @@ -48,6 +48,7 @@ case class SampleSummaryStatsState(var sampleId: String, var momentAggState: Mom case class PerSampleSummaryStatistics( genotypes: Expression, field: Expression, + genotypeInfo: Option[GenotypeInfo] = None, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[mutable.ArrayBuffer[SampleSummaryStatsState]] @@ -66,7 +67,7 @@ case class PerSampleSummaryStatistics( } override def genotypesExpr: Expression = genotypes - override def genotypeFieldsRequired: Seq[StructField] = { + override def requiredGenotypeFields: Seq[StructField] = { if (!field.foldable || field.dataType != StringType) { throw SQLUtils.newAnalysisException("Field must be foldable string") } @@ -81,7 +82,11 @@ case class PerSampleSummaryStatistics( } override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField) - private lazy val hasSampleIds = optionalFieldIndices(0) != -1 + + override def withGenotypeInfo(genotypeInfo: GenotypeInfo): PerSampleSummaryStatistics = { + copy(genotypeInfo = Some(genotypeInfo)) + } + private lazy val hasSampleIds = getGenotypeInfo.optionalFieldIndices(0) != -1 override def createAggregationBuffer(): ArrayBuffer[SampleSummaryStatsState] = { mutable.ArrayBuffer[SampleSummaryStatsState]() @@ -100,22 +105,22 @@ case class PerSampleSummaryStatistics( } private lazy val updateStateFn: (MomentAggState, InternalRow) => Unit = { - genotypeFieldsRequired.head.dataType match { + requiredGenotypeFields.head.dataType match { case FloatType => (state, genotype) => { - state.update(genotype.getFloat(genotypeFieldIndices(0))) + state.update(genotype.getFloat(getGenotypeInfo.requiredFieldIndices(0))) } case DoubleType => (state, genotype) => { - state.update(genotype.getDouble(genotypeFieldIndices(0))) + 
state.update(genotype.getDouble(getGenotypeInfo.requiredFieldIndices(0))) } case IntegerType => (state, genotype) => { - state.update(genotype.getInt(genotypeFieldIndices(0))) + state.update(genotype.getInt(getGenotypeInfo.requiredFieldIndices(0))) } case LongType => (state, genotype) => { - state.update(genotype.getLong(genotypeFieldIndices(0))) + state.update(genotype.getLong(getGenotypeInfo.requiredFieldIndices(0))) } } } @@ -132,17 +137,17 @@ case class PerSampleSummaryStatistics( if (i >= buffer.size) { val sampleId = if (hasSampleIds) { genotypesArray - .getStruct(buffer.size, genotypeStructSize) - .getString(optionalFieldIndices(0)) + .getStruct(buffer.size, getGenotypeInfo.size) + .getString(getGenotypeInfo.optionalFieldIndices(0)) } else { null } buffer.append(SampleSummaryStatsState(sampleId, MomentAggState())) } - val struct = genotypesArray.getStruct(i, genotypeStructSize) - if (!struct.isNullAt(genotypeFieldIndices(0))) { - updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, genotypeStructSize)) + val struct = genotypesArray.getStruct(i, getGenotypeInfo.size) + if (!struct.isNullAt(getGenotypeInfo.requiredFieldIndices(0))) { + updateStateFn(buffer(i).momentAggState, genotypesArray.getStruct(i, getGenotypeInfo.size)) } i += 1 } diff --git a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala index 8880d552a..efefdaa7c 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/SampleCallSummaryStats.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import io.projectglow.common.{GlowLogging, VariantSchemas} -import io.projectglow.sql.util.ExpectsGenotypeFields +import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo} import org.apache.spark.SparkEnv import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -43,6 +43,7 @@ case class CallSummaryStats( genotypes: Expression, refAllele: Expression, altAlleles: Expression, + genotypeInfo: Option[GenotypeInfo], mutableAggBufferOffset: Int, inputAggBufferOffset: Int) extends TypedImperativeAggregate[mutable.ArrayBuffer[SampleCallStats]] @@ -50,17 +51,21 @@ case class CallSummaryStats( with GlowLogging { def this(genotypes: Expression, refAllele: Expression, altAlleles: Expression) = { - this(genotypes, refAllele, altAlleles, 0, 0) + this(genotypes, refAllele, altAlleles, None, 0, 0) } override def genotypesExpr: Expression = genotypes - override def genotypeFieldsRequired: Seq[StructField] = { + override def requiredGenotypeFields: Seq[StructField] = { Seq(VariantSchemas.callsField) } override def optionalGenotypeFields: Seq[StructField] = Seq(VariantSchemas.sampleIdField) - private lazy val hasSampleIds = optionalFieldIndices(0) != -1 + + override def withGenotypeInfo(genotypeInfo: GenotypeInfo): CallSummaryStats = { + copy(genotypeInfo = Some(genotypeInfo)) + } + private lazy val hasSampleIds = getGenotypeInfo.optionalFieldIndices(0) != -1 override def children: Seq[Expression] = Seq(genotypes, refAllele, altAlleles) override def nullable: Boolean = false @@ -116,12 +121,12 @@ case class CallSummaryStats( var j = 0 while (j < genotypesArr.numElements()) { val struct = genotypesArr - .getStruct(j, genotypeStructSize) + .getStruct(j, getGenotypeInfo.size) // Make sure the buffer has an entry for this sample if (j >= buffer.size) { val sampleId = if (hasSampleIds) { - struct.getUTF8String(optionalFieldIndices(0)) + struct.getUTF8String(getGenotypeInfo.optionalFieldIndices(0)) } else { null } @@ -130,7 +135,7 @@ case class CallSummaryStats( val stats = buffer(j) val calls = struct - .getStruct(genotypeFieldIndices.head, 2) + .getStruct(getGenotypeInfo.requiredFieldIndices.head, 2) .getArray(0) var k = 0 var isUncalled = 
false diff --git a/core/src/main/scala/io/projectglow/sql/expressions/VariantQcExprs.scala b/core/src/main/scala/io/projectglow/sql/expressions/VariantQcExprs.scala index be77901e1..96751a32e 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/VariantQcExprs.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/VariantQcExprs.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} import org.apache.spark.sql.types._ import io.projectglow.common.{GlowLogging, VariantSchemas} -import io.projectglow.sql.util.{ExpectsGenotypeFields, LeveneHaldane, Rewrite} +import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo, LeveneHaldane, Rewrite} /** * Contains implementations of QC functions. These implementations are called during both @@ -231,7 +231,11 @@ object VariantQcExprs extends GlowLogging { } } -case class HardyWeinberg(genotypes: Expression) extends UnaryExpression with ExpectsGenotypeFields { +case class HardyWeinberg(genotypes: Expression, genotypeInfo: Option[GenotypeInfo]) + extends UnaryExpression + with ExpectsGenotypeFields { + def this(genotypes: Expression) = this(genotypes, None) + override def dataType: DataType = StructType( Seq( @@ -242,24 +246,34 @@ case class HardyWeinberg(genotypes: Expression) extends UnaryExpression with Exp override def genotypesExpr: Expression = genotypes - override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.callsField) + override def requiredGenotypeFields: Seq[StructField] = Seq(VariantSchemas.callsField) + + override def withGenotypeInfo(genotypeInfo: GenotypeInfo): HardyWeinberg = { + copy(genotypeInfo = Some(genotypeInfo)) + } override def child: Expression = genotypes override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val fn = "io.projectglow.sql.expressions.VariantQcExprs.hardyWeinberg" - nullSafeCodeGen(ctx, ev, calls => { - s""" - |${ev.value} = $fn($calls, $genotypeStructSize, ${genotypeFieldIndices.head}); + 
nullSafeCodeGen( + ctx, + ev, + calls => { + s""" + |${ev.value} = $fn($calls, ${getGenotypeInfo.size}, ${getGenotypeInfo + .requiredFieldIndices + .head}); """.stripMargin - }) + } + ) } override def nullSafeEval(input: Any): Any = { VariantQcExprs.hardyWeinberg( input.asInstanceOf[ArrayData], - genotypeStructSize, - genotypeFieldIndices.head + getGenotypeInfo.size, + getGenotypeInfo.requiredFieldIndices.head ) } } @@ -270,29 +284,43 @@ object HardyWeinberg { case class HardyWeinbergStruct(hetFreqHwe: Double, pValueHwe: Double) -case class CallStats(genotypes: Expression) extends UnaryExpression with ExpectsGenotypeFields { +case class CallStats(genotypes: Expression, genotypeInfo: Option[GenotypeInfo]) + extends UnaryExpression + with ExpectsGenotypeFields { + def this(genotypes: Expression) = this(genotypes, None) + lazy val dataType: DataType = CallStats.schema override def genotypesExpr: Expression = genotypes - override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.callsField) + override def requiredGenotypeFields: Seq[StructField] = Seq(VariantSchemas.callsField) + + override def withGenotypeInfo(genotypeInfo: GenotypeInfo): CallStats = { + copy(genotypeInfo = Some(genotypeInfo)) + } override def child: Expression = genotypes override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val fn = "io.projectglow.sql.expressions.VariantQcExprs.callStats" - nullSafeCodeGen(ctx, ev, calls => { - s""" - |${ev.value} = $fn($calls, $genotypeStructSize, ${genotypeFieldIndices.head}); + nullSafeCodeGen( + ctx, + ev, + calls => { + s""" + |${ev.value} = $fn($calls, ${getGenotypeInfo.size}, ${getGenotypeInfo + .requiredFieldIndices + .head}); """.stripMargin - }) + } + ) } override def nullSafeEval(input: Any): Any = { VariantQcExprs.callStats( input.asInstanceOf[ArrayData], - genotypeStructSize, - genotypeFieldIndices.head + getGenotypeInfo.size, + getGenotypeInfo.requiredFieldIndices.head ) } } diff --git 
a/core/src/main/scala/io/projectglow/sql/expressions/VariantUtilExprs.scala b/core/src/main/scala/io/projectglow/sql/expressions/VariantUtilExprs.scala index 05f95ba43..0a18c87f2 100644 --- a/core/src/main/scala/io/projectglow/sql/expressions/VariantUtilExprs.scala +++ b/core/src/main/scala/io/projectglow/sql/expressions/VariantUtilExprs.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import io.projectglow.common.VariantSchemas -import io.projectglow.sql.util.ExpectsGenotypeFields +import io.projectglow.sql.util.{ExpectsGenotypeFields, GenotypeInfo} /** * Implementations of utility functions for transforming variant representations. These @@ -164,13 +164,19 @@ object VariantType { * of the calls array for the sample at that position if no calls are missing, or -1 if any calls * are missing. */ -case class GenotypeStates(genotypes: Expression) +case class GenotypeStates(genotypes: Expression, genotypeInfo: Option[GenotypeInfo]) extends UnaryExpression with ExpectsGenotypeFields { + def this(genotypes: Expression) = this(genotypes, None) + override def genotypesExpr: Expression = genotypes - override def genotypeFieldsRequired: Seq[StructField] = Seq(VariantSchemas.callsField) + override def requiredGenotypeFields: Seq[StructField] = Seq(VariantSchemas.callsField) + + override def withGenotypeInfo(genotypeInfo: GenotypeInfo): GenotypeStates = { + copy(genotypes, Some(genotypeInfo)) + } override def child: Expression = genotypes @@ -178,18 +184,24 @@ case class GenotypeStates(genotypes: Expression) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val fn = "io.projectglow.sql.expressions.VariantUtilExprs.genotypeStates" - nullSafeCodeGen(ctx, ev, calls => { - s""" - |${ev.value} = $fn($calls, $genotypeStructSize, ${genotypeFieldIndices.head}); + nullSafeCodeGen( + ctx, + ev, + calls => { + s""" + |${ev.value} = $fn($calls, ${getGenotypeInfo.size}, ${getGenotypeInfo + 
.requiredFieldIndices + .head}); """.stripMargin - }) + } + ) } override def nullSafeEval(input: Any): Any = { VariantUtilExprs.genotypeStates( input.asInstanceOf[ArrayData], - genotypeStructSize, - genotypeFieldIndices.head + getGenotypeInfo.size, + getGenotypeInfo.requiredFieldIndices.head ) } } diff --git a/core/src/main/scala/io/projectglow/sql/optimizer/hlsOptimizerRules.scala b/core/src/main/scala/io/projectglow/sql/optimizer/hlsOptimizerRules.scala index aa8d8da4a..1feb725d7 100644 --- a/core/src/main/scala/io/projectglow/sql/optimizer/hlsOptimizerRules.scala +++ b/core/src/main/scala/io/projectglow/sql/optimizer/hlsOptimizerRules.scala @@ -16,14 +16,14 @@ package io.projectglow.sql.optimizer -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAlias} import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import io.projectglow.common.GlowLogging import io.projectglow.sql.expressions._ -import io.projectglow.sql.util.RewriteAfterResolution +import io.projectglow.sql.util.{ExpectsGenotypeFields, RewriteAfterResolution} /** * Simple optimization rule that handles expression rewrites @@ -86,3 +86,17 @@ object ResolveExpandStructRule extends Rule[LogicalPlan] { } } } + +/** + * Resolve required genotype fields to their indices within the child expression. Performing + * this resolution explicitly guards against expressions like [[org.apache.spark.sql.catalyst.expressions.ArraysZip]] + * that can lose field names during physical planning. 
+ */ +object ResolveGenotypeFields extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case e: ExpectsGenotypeFields + if !e.resolved && e.childrenResolved && e + .checkInputDataTypes() == TypeCheckResult.TypeCheckSuccess => + e.resolveGenotypeInfo() + } +} diff --git a/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala b/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala index 08e9d3b8b..13e1b96cf 100644 --- a/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala +++ b/core/src/main/scala/io/projectglow/sql/util/ExpectsGenotypeFields.scala @@ -21,36 +21,76 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExcept import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable} import org.apache.spark.sql.types.{ArrayType, DataType, StructField, StructType} +/** + * Stores the indices of required and optional fields within the genotype element struct after + * resolution. + * @param size The number of fields in the struct + * @param requiredFieldIndices The indices of required fields. 0 <= idx < size. + * @param optionalFieldIndices The indices of optional fields. -1 if the field is not present. + */ +case class GenotypeInfo(size: Int, requiredFieldIndices: Seq[Int], optionalFieldIndices: Seq[Int]) + /** + * A trait to simplify type checking and reading for expressions that operate on arrays of genotype + * data with the expectation that certain fields exists. + * + * Note: This trait introduces complexity during resolution and analysis, and prevents + * nested column pruning. Prefer writing new functions as rewrites when possible.
*/ +@deprecated("Write functions as rewrites when possible", "0.4.1") trait ExpectsGenotypeFields extends Expression { + def genotypeInfo: Option[GenotypeInfo] + final def getGenotypeInfo: GenotypeInfo = { + genotypeInfo.get + } + + override lazy val resolved: Boolean = { + childrenResolved && genotypeInfo.isDefined + } + + /** + * Make a copy of this expression with [[GenotypeInfo]] filled in. + */ + protected def withGenotypeInfo(genotypeInfo: GenotypeInfo): Expression + + /** + * Resolve the required field names into positions within the [[genotypesExpr]] element struct. + * + * This function should only be called after a successful type check. + * + * @return A new expression with a defined [[GenotypeInfo]] + */ + def resolveGenotypeInfo(): Expression = { + val info = GenotypeInfo(genotypeStructSize, requiredFieldIndices, optionalFieldIndices) + withGenotypeInfo(info) + } + protected def genotypesExpr: Expression - protected def genotypeFieldsRequired: Seq[StructField] + protected def requiredGenotypeFields: Seq[StructField] protected def optionalGenotypeFields: Seq[StructField] = Seq.empty - private lazy val gStruct = genotypesExpr - .dataType - .asInstanceOf[ArrayType] - .elementType - .asInstanceOf[StructType] + private def gStruct = + genotypesExpr + .dataType + .asInstanceOf[ArrayType] + .elementType + .asInstanceOf[StructType] - protected lazy val genotypeFieldIndices: Seq[Int] = { - genotypeFieldsRequired.map { f => + private def requiredFieldIndices: Seq[Int] = { + requiredGenotypeFields.map { f => gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _)) } } - protected lazy val optionalFieldIndices: Seq[Int] = { + private def optionalFieldIndices: Seq[Int] = { optionalGenotypeFields.map { f => gStruct.indexWhere(SQLUtils.structFieldsEqualExceptNullability(f, _)) } } - protected lazy val genotypeStructSize: Int = { + private def genotypeStructSize: Int = { gStruct.length } @@ -61,7 +101,7 @@ trait ExpectsGenotypeFields extends Expression { 
return TypeCheckResult.TypeCheckFailure("Genotypes field must be an array of structs") } - val missingFields = genotypeFieldsRequired.zip(genotypeFieldIndices).collect { + val missingFields = requiredGenotypeFields.zip(requiredFieldIndices).collect { case (f, -1) => f } diff --git a/core/src/test/scala/io/projectglow/sql/util/ExpectsGenotypeFieldsSuite.scala b/core/src/test/scala/io/projectglow/sql/util/ExpectsGenotypeFieldsSuite.scala new file mode 100644 index 000000000..bf828b2fe --- /dev/null +++ b/core/src/test/scala/io/projectglow/sql/util/ExpectsGenotypeFieldsSuite.scala @@ -0,0 +1,57 @@ +/* + * Copyright 2019 The Glow Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.projectglow.sql.util + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.functions._ + +import io.projectglow.Glow +import io.projectglow.functions._ +import io.projectglow.sql.GlowBaseTest + +class ExpectsGenotypeFieldsSuite extends GlowBaseTest { + lazy val gatkTestVcf = + s"$testDataHome/variantsplitternormalizer-test/test_left_align_hg38_altered.vcf" + lazy val sess = spark + + // This is how we originally detected an issue where ExpectsGenotypeFields succeeds during + // resolution but fails during physical planning. 
+ // PR: https://github.com/projectglow/glow/pull/224 + test("use genotype_states after splitting multiallelics") { + val df = spark.read.format("vcf").load(gatkTestVcf) + val split = Glow.transform("split_multiallelics", df) + split.select(genotype_states(col("genotypes"))).collect() + } + + test("use genotype_states after array_zip") { + import sess.implicits._ + val df = spark + .createDataFrame(Seq((Seq("a"), Seq(Seq(1, 1))))) + .withColumnRenamed("_1", "sampleId") + .withColumnRenamed("_2", "calls") + val zipped = df.select(arrays_zip(col("sampleId"), col("calls")).as("genotypes")) + val states = zipped.select(genotype_states(col("genotypes"))) + assert(states.as[Seq[Int]].head == Seq(2)) + } + + test("type check") { + val df = spark.createDataFrame(Seq(Tuple1("a"))).withColumnRenamed("_1", "sampleId") + val withGenotypes = df.select(array(struct("sampleId")).as("genotypes")) + val ex = intercept[AnalysisException](withGenotypes.select(genotype_states(col("genotypes")))) + assert(ex.message.contains("Genotype struct was missing required fields: (name: calls")) + } +} diff --git a/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/SplitMultiallelicsTransformerSuite.scala b/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/SplitMultiallelicsTransformerSuite.scala index ce88e60ec..a52bfbb8a 100644 --- a/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/SplitMultiallelicsTransformerSuite.scala +++ b/core/src/test/scala/io/projectglow/transformers/splitmultiallelics/SplitMultiallelicsTransformerSuite.scala @@ -21,6 +21,9 @@ import io.projectglow.common.VariantSchemas._ import io.projectglow.common.{CommonOptions, GlowLogging} import io.projectglow.sql.GlowBaseTest import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute + +import io.projectglow.sql.expressions.GenotypeStates import io.projectglow.transformers.splitmultiallelics.SplitMultiallelicsTransformer._ class 
SplitMultiallelicsTransformerSuite extends GlowBaseTest with GlowLogging { @@ -162,5 +165,4 @@ class SplitMultiallelicsTransformerSuite extends GlowBaseTest with GlowLogging { ) } - }