Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
api/python/test/canary/group_bys/gcp/purchases.py — 69 changes: 68 additions & 1 deletion
@@ -133,4 +133,71 @@
operation=Operation.LAST_K(10),
),
],
)
)

# This source is raw purchase events. Every time a user makes a purchase, it will be one entry in this source.
source_notds = Source(
    events=EventSource(
        table="data.purchases_notds",  # This points to the log table in the warehouse with historical purchase events, updated in batch daily
        topic=None,  # See the 'returns' GroupBy for an example that has a streaming source configured. In this case, this would be the streaming source topic that can be listened to for realtime events
        query=Query(
            selects=selects("user_id", "purchase_price"),  # Select the fields we care about
            time_column="ts",  # The event time
            partition_column="notds",
        ),
    )
)

v1_test_notds = GroupBy(
    backfill_start_date="2023-11-01",
    sources=[source_notds],
    keys=["user_id"],  # We are aggregating by user
    online=True,
    aggregations=[
        Aggregation(
            input_column="purchase_price",
            operation=Operation.SUM,
            windows=window_sizes
        ),  # The sum of purchase prices in various windows
        Aggregation(
            input_column="purchase_price",
            operation=Operation.COUNT,
            windows=window_sizes
        ),  # The count of purchases in various windows
        Aggregation(
            input_column="purchase_price",
            operation=Operation.AVERAGE,
            windows=window_sizes
        ),  # The average purchase price by user in various windows
        Aggregation(
            input_column="purchase_price",
            operation=Operation.LAST_K(10),
        ),
    ],
)

v1_dev_notds = GroupBy(
    backfill_start_date="2023-11-01",
    sources=[source_notds],
    keys=["user_id"],  # We are aggregating by user
    online=True,
    aggregations=[
        Aggregation(
            input_column="purchase_price",
            operation=Operation.SUM,
            windows=window_sizes
        ),  # The sum of purchase prices in various windows
        Aggregation(
            input_column="purchase_price",
            operation=Operation.COUNT,
            windows=window_sizes
        ),  # The count of purchases in various windows
        Aggregation(
            input_column="purchase_price",
            operation=Operation.AVERAGE,
            windows=window_sizes
        ),  # The average purchase price by user in various windows
        Aggregation(
            input_column="purchase_price",
            operation=Operation.LAST_K(10),
        ),
    ],
)

api/python/test/canary/joins/gcp/training_set.py — 33 changes: 30 additions & 3 deletions
@@ -1,4 +1,4 @@
from group_bys.gcp.purchases import v1_dev, v1_test
from group_bys.gcp import purchases

from ai.chronon.api.ttypes import EventSource, Source
from ai.chronon.join import Join, JoinPart
@@ -23,13 +23,40 @@
v1_test = Join(
left=source,
right_parts=[
JoinPart(group_by=v1_test)
JoinPart(group_by=purchases.v1_test)
],
)

v1_dev = Join(
left=source,
right_parts=[
JoinPart(group_by=v1_dev)
JoinPart(group_by=purchases.v1_dev)
],
)

source_notds = Source(
    events=EventSource(
        table="data.checkouts_notds",
        query=Query(
            selects=selects(
                "user_id"
            ),  # The primary key used to join various GroupBys together
            time_column="ts",  # The event time used to compute feature values as-of
            partition_column="notds",
        ),
    )
)

v1_test_notds = Join(
    left=source_notds,
    right_parts=[
        JoinPart(group_by=purchases.v1_test_notds)
    ],
)

v1_dev_notds = Join(
    left=source_notds,
    right_parts=[
        JoinPart(group_by=purchases.v1_dev_notds)
    ],
)
scripts/distribution/run_gcp_quickstart.sh — 10 changes: 9 additions & 1 deletion
@@ -72,11 +72,16 @@ if [[ "$ENVIRONMENT" == "canary" ]]; then
bq rm -f -t canary-443022:data.gcp_purchases_v1_view_test
bq rm -f -t canary-443022:data.gcp_purchases_v1_test_upload
bq rm -f -t canary-443022:data.gcp_training_set_v1_test
bq rm -f -t canary-443022:data.gcp_purchases_v1_test_notds
bq rm -f -t canary-443022:data.gcp_training_set_v1_test_notds

else
bq rm -f -t canary-443022:data.gcp_purchases_v1_dev
bq rm -f -t canary-443022:data.gcp_purchases_v1_view_dev
bq rm -f -t canary-443022:data.gcp_purchases_v1_dev_upload
bq rm -f -t canary-443022:data.gcp_training_set_v1_dev
bq rm -f -t canary-443022:data.gcp_purchases_v1_dev_notds
bq rm -f -t canary-443022:data.gcp_training_set_v1_dev_notds
fi
#TODO: delete bigtable rows

@@ -148,12 +153,15 @@ fail_if_bash_failed $?
echo -e "${GREEN}<<<<<.....................................BACKFILL-JOIN.....................................>>>>>\033[0m"
if [[ "$ENVIRONMENT" == "canary" ]]; then
zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/joins/gcp/training_set.v1_test --start-ds 2023-11-01 --end-ds 2023-12-01
zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/joins/gcp/training_set.v1_test_notds --start-ds 2023-11-01 --end-ds 2023-12-01

else
zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/joins/gcp/training_set.v1_dev --start-ds 2023-11-01 --end-ds 2023-12-01
zipline run --repo=$CHRONON_ROOT --version $VERSION --mode backfill --conf compiled/joins/gcp/training_set.v1_dev_notds --start-ds 2023-11-01 --end-ds 2023-12-01
fi

fail_if_bash_failed $?


echo -e "${GREEN}<<<<<.....................................CHECK-PARTITIONS.....................................>>>>>\033[0m"
EXPECTED_PARTITION="2023-11-30"
if [[ "$ENVIRONMENT" == "canary" ]]; then
spark/src/main/scala/ai/chronon/spark/JoinUtils.scala — 15 changes: 7 additions & 8 deletions
@@ -83,9 +83,11 @@ object JoinUtils {
val effectiveLeftSpec = joinConf.left.partitionSpec
val effectiveLeftRange = range.translate(effectiveLeftSpec)

val partitionColumnOfLeft = effectiveLeftSpec.column

var df = tableUtils.scanDf(joinConf.left.query,
joinConf.left.table,
Some((Map(tableUtils.partitionColumn -> null) ++ timeProjection).toMap),
Some((Map(partitionColumnOfLeft -> null) ++ timeProjection).toMap),
range = Some(effectiveLeftRange))

limit.foreach(l => df = df.limit(l))
Expand Down Expand Up @@ -171,7 +173,7 @@ object JoinUtils {
val leftEnd = Option(leftSource.query.endPartition).getOrElse(endPartition)

logger.info(s"Attempting to fill join partition range: $leftStart to $leftEnd")
PartitionRange(leftStart, leftEnd)(tableUtils.partitionSpec)
PartitionRange(leftStart, leftEnd)(leftSpec)
[Review thread]
Collaborator: we might want to default to the tableUtils.partitionSpec if there's no leftSpec?
Contributor: that auto magically happens
Contributor Author: "Implicit boy"
}
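A minimal, self-contained sketch of the fallback the reviewers discuss above, under the contributor's claim that the default spec is picked up "auto magically" via an implicit when the left source carries no spec of its own. The PartitionSpec and PartitionRange types below are simplified stand-ins for illustration, not Chronon's actual API.

// Simplified stand-ins, for illustration only; not Chronon's actual types.
case class PartitionSpec(column: String, format: String)

object Defaults {
  // plays the role of tableUtils.partitionSpec
  implicit val defaultSpec: PartitionSpec = PartitionSpec("ds", "yyyy-MM-dd")
}

// The second parameter list takes the spec implicitly, so callers that do not
// pass one fall back to whatever default is in implicit scope.
case class PartitionRange(start: String, end: String)(implicit val spec: PartitionSpec)

object FallbackSketch extends App {
  import Defaults._

  // Left source with its own spec (e.g. a table partitioned on "notds"):
  val leftSpec = PartitionSpec("notds", "yyyy-MM-dd")
  println(PartitionRange("2023-11-01", "2023-12-01")(leftSpec).spec.column) // notds

  // No explicit spec: the implicit default kicks in.
  println(PartitionRange("2023-11-01", "2023-12-01").spec.column) // ds
}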

/** *
@@ -323,9 +325,8 @@
groupByKeyExpressions
.map { case (keyName, groupByKeyExpression) =>
val leftSideKeyName = joinPart.rightToLeft(keyName)
logger.info(
s"KeyName: $keyName, leftSide KeyName: $leftSideKeyName , Join right to left: ${joinPart.rightToLeft
.mkString(", ")}")
logger.info(s"KeyName: $keyName, leftSide KeyName: $leftSideKeyName , " +
s"Join right to left: ${joinPart.rightToLeft.mkString(", ")}")
val values = collectedLeft.map(row => row.getAs[Any](leftSideKeyName))
// Check for null keys, warn if found, err if all null
val (notNullValues, nullValues) = values.partition(_ != null)
@@ -492,9 +493,7 @@
}

def parseSkewKeys(jmap: java.util.Map[String, java.util.List[String]]): Option[Map[String, Seq[String]]] = {
Option(jmap).map(_.toScala.map { case (key, list) =>
key -> list.asScala
}.toMap)
Option(jmap).map(_.toScala.map { case (key, list) => key -> list.asScala }.toMap)
}

def shiftDays(leftDataModel: DataModel, joinPart: JoinPart, leftRange: PartitionRange): PartitionRange = {
@@ -60,9 +60,7 @@ class StagingQuery(stagingQueryConf: api.StagingQuery, endPartition: String, tab
stagingQueryUnfilledRanges.foreach { stagingQueryUnfilledRange =>
try {
val stepRanges = stepDays.map(stagingQueryUnfilledRange.steps).getOrElse(Seq(stagingQueryUnfilledRange))
logger.info(s"Staging query ranges to compute: ${stepRanges.map {
_.toString
}.pretty}")
logger.info(s"Staging query ranges to compute: ${stepRanges.map { _.toString }.pretty}")
stepRanges.zipWithIndex.foreach { case (range, index) =>
val progress = s"| [${index + 1}/${stepRanges.size}]"
logger.info(s"Computing staging query for range: $range $progress")
@@ -130,13 +130,15 @@ class TableUtils(@transient val sparkSession: SparkSession) extends Serializable
if (!tableReachable(tableName)) return List.empty[String]
val rangeWheres = andPredicates(partitionRange.map(_.whereClauses).getOrElse(Seq.empty))

val effectivePartColumn = tablePartitionSpec.map(_.column).getOrElse(partitionColumnName)
[Review thread]
@tchow-zlai (Collaborator), May 16, 2025: we might want to prioritize the spec in the range, then the table spec, then the default spec?

val partitions = tableFormatProvider
.readFormat(tableName)
.map((format) => {
logger.info(
s"Getting partitions for ${tableName} with partitionColumnName ${partitionColumnName} and subpartitions: ${subPartitionsFilter}")
s"Getting partitions for ${tableName} with partitionColumnName ${effectivePartColumn} and subpartitions: ${subPartitionsFilter}")
val partitions =
format.primaryPartitions(tableName, partitionColumnName, rangeWheres, subPartitionsFilter)(sparkSession)
format.primaryPartitions(tableName, effectivePartColumn, rangeWheres, subPartitionsFilter)(sparkSession)

if (partitions.isEmpty) {
logger.info(s"No partitions found for table: $tableName with subpartition filters ${subPartitionsFilter}")
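A minimal sketch of the resolution order tchow-zlai suggests in the review comment above: the spec carried by the requested partition range, then the table-level spec, then the session default column. PartitionSpec here is a simplified stand-in and resolvePartitionColumn is a hypothetical helper, not Chronon's API.

// Simplified stand-in type and hypothetical helper, for illustration only.
case class PartitionSpec(column: String)

def resolvePartitionColumn(rangeSpec: Option[PartitionSpec],
                           tableSpec: Option[PartitionSpec],
                           defaultColumn: String): String =
  rangeSpec.map(_.column)            // spec on the partition range wins first
    .orElse(tableSpec.map(_.column)) // then the table's own spec
    .getOrElse(defaultColumn)        // then the session default

// e.g. a table declaring "notds" wins over the session default "ds":
// resolvePartitionColumn(None, Some(PartitionSpec("notds")), "ds") == "notds"
// while a spec on the range itself would take precedence over both.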
@@ -21,9 +21,9 @@ import ai.chronon.spark._
import ai.chronon.spark.catalog.{Format, IncompatibleSchemaException}
import ai.chronon.spark.test.TestUtils.makeDf
import org.apache.hadoop.hive.ql.exec.UDF
import org.apache.spark.sql.{Row, _}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.{Row, _}
import org.junit.Assert.{assertEquals, assertNull, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec
