diff --git a/online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala b/online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala index ba43a89455..53642396ba 100644 --- a/online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala +++ b/online/src/main/scala/ai/chronon/online/fetcher/JoinPartFetcher.scala @@ -90,7 +90,7 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore) val rightKeys = part.leftToRight.map { case (leftKey, rightKey) => rightKey -> request.keys(leftKey) } Left( PrefixedRequest( - part.fullPrefix, + part.columnPrefix, Request(part.groupBy.getMetaData.getName, rightKeys, request.atMillis, Some(joinContextInner)))) } @@ -158,7 +158,7 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore) response .map { valueMap => if (valueMap != null) { - valueMap.map { case (aggName, aggValue) => prefix + "_" + aggName -> aggValue } + valueMap.map { case (aggName, aggValue) => prefix + aggName -> aggValue } } else { Map.empty[String, AnyRef] } @@ -169,7 +169,7 @@ class JoinPartFetcher(fetchContext: FetchContext, metadataStore: MetadataStore) if (fetchContext.debug || Math.random() < 0.001) { println(s"Failed to fetch $groupByRequest with \n${ex.traceString}") } - Map(prefix + "_exception" -> ex.traceString) + Map(prefix + "exception" -> ex.traceString) } .get } diff --git a/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala b/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala index 7ed57ec270..a86f63588a 100644 --- a/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala +++ b/online/src/test/scala/ai/chronon/online/test/FetcherBaseTest.scala @@ -211,7 +211,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M )) ) - val result = baseFetcher.parseGroupByResponse("prefix", request, response) + val result = baseFetcher.parseGroupByResponse("prefix_", request, response) assertEquals(result, Map("prefix_key" -> "value")) } @@ -227,7 +227,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M )) ) - val result = baseFetcher.parseGroupByResponse("prefix", request, response) + val result = baseFetcher.parseGroupByResponse("prefix_", request, response) result shouldBe Map() } @@ -243,7 +243,7 @@ class FetcherBaseTest extends AnyFlatSpec with MockitoSugar with Matchers with M )) ) - val result = baseFetcher.parseGroupByResponse("prefix", request, response) + val result = baseFetcher.parseGroupByResponse("prefix_", request, response) result.keySet shouldBe Set("prefix_exception") } diff --git a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala index 01c432fdea..f3adeee77b 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherFailureTest.scala @@ -80,7 +80,7 @@ class FetcherFailureTest extends AnyFlatSpec { val request = Request(joinConf.metaData.name, keyMap) val (responses, _) = FetcherTestUtil.joinResponses(spark, Array(request), mockApi) val responseMap = responses.head.values.get - val exceptionKeys = joinConf.joinPartOps.map(jp => jp.fullPrefix + "_exception") + val exceptionKeys = joinConf.joinPartOps.map(jp => jp.columnPrefix + "exception") exceptionKeys.foreach(k => assertTrue(responseMap.contains(k))) } diff --git a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTest.scala index b42c868d12..53154dae01 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTest.scala @@ -102,8 +102,11 @@ class FetcherTest extends AnyFlatSpec { tableUtils.sql( s"SELECT * FROM $joinTable WHERE ts >= unix_timestamp('$endDs', '${tableUtils.partitionSpec.format}')") } - val endDsQueries = endDsEvents.drop(endDsEvents.schema.fieldNames.filter(_.contains("unit_test")): _*) + // Keep only left-side columns (keys, ts, ds) and drop all feature columns val keys = joinConf.leftKeyCols + val leftSideColumns = keys ++ Array(Constants.TimeColumn, tableUtils.partitionColumn) + val columnsToKeep = endDsEvents.schema.fieldNames.filter(leftSideColumns.contains) + val endDsQueries = endDsEvents.select(columnsToKeep.map(col): _*) val keyIndices = keys.map(endDsQueries.schema.fieldIndex) val tsIndex = endDsQueries.schema.fieldIndex(Constants.TimeColumn) val metadataStore = new fetcher.MetadataStore(FetchContext(inMemoryKvStore)) diff --git a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala index ccf26b43b6..a0f02054fb 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/fetcher/FetcherTestUtil.scala @@ -440,29 +440,35 @@ object FetcherTestUtil { queriesDf.show() queriesDf.save(queriesTable) - val joinConf = Builders.Join( - left = Builders.Source.events(Builders.Query(startPartition = today), table = queriesTable), - joinParts = Seq( - Builders.JoinPart(groupBy = vendorRatingsGroupBy, keyMapping = Map("vendor_id" -> "vendor")), - Builders.JoinPart(groupBy = userPaymentsGroupBy, keyMapping = Map("user_id" -> "user")), - Builders.JoinPart(groupBy = userBalanceGroupBy, keyMapping = Map("user_id" -> "user")), - Builders.JoinPart(groupBy = reviewGroupBy), - Builders.JoinPart(groupBy = creditGroupBy, prefix = "b"), - Builders.JoinPart(groupBy = creditGroupBy, prefix = "a"), - Builders.JoinPart(groupBy = creditDerivationGroupBy, prefix = "c") - ), - metaData = Builders.MetaData(name = "test.payments_join", - namespace = namespace, - team = "chronon", - consistencySamplePercent = 30), - derivations = Seq( - Builders.Derivation("*", "*"), - Builders.Derivation("hist_3d", "unit_test_vendor_ratings_txn_types_histogram_3d"), - Builders.Derivation("payment_variance", "unit_test_user_payments_payment_variance/2"), - Builders.Derivation("derived_ds", "from_unixtime(ts/1000, 'yyyy-MM-dd')"), - Builders.Derivation("direct_ds", "ds") + val joinConf = Builders + .Join( + left = Builders.Source.events(Builders.Query(startPartition = today), table = queriesTable), + joinParts = Seq( + Builders + .JoinPart(groupBy = vendorRatingsGroupBy, keyMapping = Map("vendor_id" -> "vendor")) + .setUseLongNames(false), + Builders + .JoinPart(groupBy = userPaymentsGroupBy, keyMapping = Map("user_id" -> "user")) + .setUseLongNames(false), + Builders.JoinPart(groupBy = userBalanceGroupBy, keyMapping = Map("user_id" -> "user")).setUseLongNames(false), + Builders.JoinPart(groupBy = reviewGroupBy).setUseLongNames(false), + Builders.JoinPart(groupBy = creditGroupBy, prefix = "b").setUseLongNames(false), + Builders.JoinPart(groupBy = creditGroupBy, prefix = "a").setUseLongNames(false), + Builders.JoinPart(groupBy = creditDerivationGroupBy, prefix = "c").setUseLongNames(false) + ), + metaData = Builders.MetaData(name = "test.payments_join", + namespace = namespace, + team = "chronon", + consistencySamplePercent = 30), + derivations = Seq( + Builders.Derivation("*", "*"), + Builders.Derivation("hist_3d", "vendor_txn_types_histogram_3d"), + Builders.Derivation("payment_variance", "user_payment_variance/2"), + Builders.Derivation("derived_ds", "from_unixtime(ts/1000, 'yyyy-MM-dd')"), + Builders.Derivation("direct_ds", "ds") + ) ) - ) + .setUseLongNames(false) joinConf }