17 changes: 9 additions & 8 deletions api/py/ai/chronon/repo/run.py
@@ -64,8 +64,7 @@

# Constants for supporting multiple spark versions.
SUPPORTED_SPARK = ["2.4.0", "3.1.1", "3.2.1", "3.5.1"]
SCALA_VERSION_FOR_SPARK = {"2.4.0": "2.11",
"3.1.1": "2.12", "3.2.1": "2.13", "3.5.1": "2.12"}
SCALA_VERSION_FOR_SPARK = {"2.4.0": "2.11", "3.1.1": "2.12", "3.2.1": "2.13", "3.5.1": "2.12"}

MODE_ARGS = {
"backfill": OFFLINE_ARGS,
@@ -407,9 +406,11 @@ def __init__(self, args, jar_path):
self.sub_help = args["sub_help"]
self.mode = args["mode"]
self.online_jar = args["online_jar"]
self.dataproc = args["dataproc"]
valid_jar = args["online_jar"] and os.path.exists(args["online_jar"])

# fetch online jar if necessary
if (self.mode in ONLINE_MODES) and (not args["sub_help"]) and not valid_jar:
if (self.mode in ONLINE_MODES) and (not args["sub_help"] and not self.dataproc) and not valid_jar:
Collaborator: feels like we might want a flag that's not just for Dataproc but whether this command is an offline batch run? Not critically blocking though.

Contributor: +1

Contributor (author): yeah, it's getting pretty crowded in the code here. Re "whether this command is an offline batch run": I think that's what the opposite of `if (self.mode in ONLINE_MODES)` is supposed to represent, but then there's also the question of whether we submit to Dataproc or via plain spark-submit. The sub_help part is one we could probably catch much earlier, tbh.
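To the reviewer's point, that check could be folded into a single predicate. This is only a rough sketch against the arguments run.py already parses; the helper name `_should_fetch_online_jar` is hypothetical and not part of this PR:

```python
import os

def _should_fetch_online_jar(mode, sub_help, dataproc, online_jar, online_modes):
    """Hypothetical helper: should run.py download the online jar?

    Fetch only when the mode actually needs an online jar, we are not just
    printing subcommand help, the job is not being handed off to Dataproc
    (which ships its own dependencies), and no usable jar path was supplied.
    """
    has_valid_jar = bool(online_jar) and os.path.exists(online_jar)
    return (mode in online_modes) and not sub_help and not dataproc and not has_valid_jar
```

In the `__init__` above this would collapse the condition to a single call, which is roughly the "offline batch run" flag being suggested.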

print("Downloading online_jar")
self.online_jar = check_output("{}".format(args["online_jar_fetch"])).decode(
"utf-8"
@@ -434,8 +435,7 @@ def __init__(self, args, jar_path):
args["mode"] in possible_modes), ("Invalid mode:{} for conf:{} of type:{}, please choose from {}"
.format(args["mode"], self.conf, self.conf_type, possible_modes
))
else:
self.conf_type = args["conf_type"]

self.ds = args["end_ds"] if "end_ds" in args and args["end_ds"] else args["ds"]
self.start_ds = (
args["start_ds"] if "start_ds" in args and args["start_ds"] else None
@@ -459,7 +459,6 @@ def __init__(self, args, jar_path):
else:
self.spark_submit = args["spark_submit_path"]
self.list_apps_cmd = args["list_apps"]
self.dataproc = args["dataproc"]

def run(self):
with tempfile.TemporaryDirectory() as temp_dir:
@@ -470,7 +469,7 @@ def run(self):
script=self.render_info, conf=self.conf, ds=self.ds, repo=self.repo
)
)
elif self.sub_help or (self.mode not in SPARK_MODES):
elif (self.sub_help or (self.mode not in SPARK_MODES)) and not self.dataproc:
command_list.append(
"java -cp {jar} ai.chronon.spark.Driver {subcommand} {args}".format(
jar=self.jar_path,
@@ -660,6 +659,8 @@ def _gen_final_args(self, start_ds=None, end_ds=None, override_conf_path=None):
online_jar=self.online_jar,
online_class=self.online_class,
)
base_args = base_args + f" --conf-type={self.conf_type} " if self.conf_type else base_args

override_start_partition_arg = (
" --start-partition-override=" + start_ds if start_ds else ""
)
Expand Down Expand Up @@ -867,7 +868,7 @@ def main(ctx, conf, env, mode, dataproc, ds, app_name, start_ds, end_ds, paralle
set_runtime_env(ctx.params)
set_defaults(ctx)
jar_type = "embedded" if mode in MODES_USING_EMBEDDED else "uber"
extra_args = (" " + online_args) if mode in ONLINE_MODES else ""
extra_args = (" " + online_args) if mode in ONLINE_MODES and online_args else ""
ctx.params["args"] = " ".join(unknown_args) + extra_args
jar_path = (
chronon_jar
@@ -96,7 +96,8 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
// we can explore split points if we need custom tablet partitioning. For now though, we leave this to BT
val createTableRequest = CreateTableRequest.of(dataset).addFamily(ColumnFamilyString, DefaultGcRules)
val table = adminClient.createTable(createTableRequest)

// TODO: this actually submits an async task. thus, the submission can succeed but the task can fail.
// doesn't return a future but maybe we can poll
logger.info(s"Created table: $table")
metricsContext.increment("create.successes")

27 changes: 18 additions & 9 deletions online/src/main/scala/ai/chronon/online/MetadataDirWalker.scala
@@ -15,7 +15,12 @@ import java.nio.file.Paths
import scala.reflect.ClassTag
import scala.util.Try

class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String]) {
class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String], maybeConfType: Option[String] = None) {

val JoinKeyword = "joins"
val GroupByKeyword = "group_bys"
val StagingQueryKeyword = "staging_queries"
val ModelKeyword = "models"

@transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)
private def loadJsonToConf[T <: TBase[_, _]: Manifest: ClassTag](file: String): Option[T] = {
@@ -77,10 +82,14 @@ class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String]) {
val optConf =
try {
filePath match {
case value if value.contains("joins/") => loadJsonToConf[api.Join](filePath)
case value if value.contains("group_bys/") => loadJsonToConf[api.GroupBy](filePath)
case value if value.contains("staging_queries/") => loadJsonToConf[api.StagingQuery](filePath)
case value if value.contains("models/") => loadJsonToConf[api.Model](filePath)
case value if value.contains(s"$JoinKeyword/") || maybeConfType.contains(JoinKeyword) =>
loadJsonToConf[api.Join](filePath)
case value if value.contains(s"$GroupByKeyword/") || maybeConfType.contains(GroupByKeyword) =>
loadJsonToConf[api.GroupBy](filePath)
case value if value.contains(s"$StagingQueryKeyword/") || maybeConfType.contains(StagingQueryKeyword) =>
loadJsonToConf[api.StagingQuery](filePath)
case value if value.contains(s"$ModelKeyword/") || maybeConfType.contains(ModelKeyword) =>
loadJsonToConf[api.Model](filePath)
Comment on lines +85 to +92
Contributor (author): doing this because the filePath we're passing in from Dataproc doesn't have the full path. When we configure a job to be submitted with Dataproc, we add the GCS-uploaded config, but Dataproc places that file in the working directory rather than at its full path:

https://cloud.google.com/dataproc/docs/reference/rest/v1/SparkJob
fileUris[] (string): Optional. HCFS URIs of files to be placed in the working directory of each executor. Useful for naively parallel tasks.

Contributor (author): this is where we strip the file path down to just the filename in run.py: https://github.com/zipline-ai/chronon/blob/main/api/py/ai/chronon/repo/run.py#L575-L580 (it has some docs).
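For context, the stripping referenced above is roughly this shape. A sketch only, assuming a conf layout like production/joins/<team>/<conf_name>; the helper name is hypothetical and not the actual run.py code:

```python
import os

def dataproc_conf_args(local_conf_path):
    """Hypothetical illustration: Dataproc only sees the uploaded file's basename,
    so the conf type has to travel separately via --conf-type."""
    # e.g. "production/joins/team/my_join.v1" -> type "joins", file "my_join.v1"
    parts = local_conf_path.split("/")
    conf_type = parts[1] if len(parts) > 2 and parts[0] == "production" else None
    filename = os.path.basename(local_conf_path)
    args = f"--conf-path={filename}"
    if conf_type:
        args += f" --conf-type={conf_type}"
    return args

# dataproc_conf_args("production/joins/team/my_join.v1")
#   -> "--conf-path=my_join.v1 --conf-type=joins"
```

With only "my_join.v1" available, the contains("joins/") checks in MetadataDirWalker can never fire, which is why maybeConfType is used as the fallback here.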

}
} catch {
case e: Throwable =>
@@ -94,22 +103,22 @@ class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String]) {
val conf = optConf.get

val kVPair = filePath match {
case value if value.contains("joins/") =>
case value if value.contains(s"$JoinKeyword/") || maybeConfType.contains(JoinKeyword) =>
MetadataEndPoint
.getEndPoint[api.Join](endPointName)
.extractFn(filePath, conf.asInstanceOf[api.Join])

case value if value.contains("group_bys/") =>
case value if value.contains(s"$GroupByKeyword/") || maybeConfType.contains(GroupByKeyword) =>
MetadataEndPoint
.getEndPoint[api.GroupBy](endPointName)
.extractFn(filePath, conf.asInstanceOf[api.GroupBy])

case value if value.contains("staging_queries/") =>
case value if value.contains(s"$StagingQueryKeyword/") || maybeConfType.contains(StagingQueryKeyword) =>
MetadataEndPoint
.getEndPoint[api.StagingQuery](endPointName)
.extractFn(filePath, conf.asInstanceOf[api.StagingQuery])

case value if value.contains("models/") =>
case value if value.contains(s"$ModelKeyword/") || maybeConfType.contains(ModelKeyword) =>
MetadataEndPoint
.getEndPoint[api.Model](endPointName)
.extractFn(filePath, conf.asInstanceOf[api.Model])
3 changes: 3 additions & 0 deletions online/src/main/scala/ai/chronon/online/MetadataStore.scala
@@ -232,7 +232,10 @@ class MetadataStore(kvStore: KVStore, val dataset: String = MetadataDataset, tim
def create(dataset: String): Unit = {
try {
logger.info(s"Creating dataset: $dataset")
// TODO: this is actually just an async task. it doesn't block and thus we don't actually
// know if it successfully created the dataset
Comment on lines +235 to +236
Contributor (author): @piyush-zlai - was thinking I could poll? The issue is this doesn't return a future.

kvStore.create(dataset)

logger.info(s"Successfully created dataset: $dataset")
} catch {
case e: Exception =>
13 changes: 8 additions & 5 deletions spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -86,17 +86,20 @@ object Driver {
def parseConf[T <: TBase[_, _]: Manifest: ClassTag](confPath: String): T =
ThriftJsonCodec.fromJsonFile[T](confPath, check = true)

trait AddGcpSubCommandArgs {
trait SharedSubCommandArgs {
this: ScallopConf =>
val isGcp: ScallopOption[Boolean] =
opt[Boolean](required = false, default = Some(false), descr = "Whether to use GCP")
val gcpProjectId: ScallopOption[String] =
opt[String](required = false, descr = "GCP project id")
val gcpBigtableInstanceId: ScallopOption[String] =
opt[String](required = false, descr = "GCP BigTable instance id")

val confType: ScallopOption[String] =
opt[String](required = false, descr = "Type of the conf to run. ex: join, group-by, etc")
Comment on lines +98 to +99
Contributor (author): mentioned earlier, but this is needed because there's logic in Scala that depends on the conf file path containing keywords like .../joins/... and extracting the conf type from the path. Since we don't have the full file path (see above), we have to use confType, which is set in run.py.

}

trait OfflineSubcommand extends AddGcpSubCommandArgs {
trait OfflineSubcommand extends SharedSubCommandArgs {
this: ScallopConf =>
val confPath: ScallopOption[String] = opt[String](required = true, descr = "Path to conf")

@@ -584,7 +587,7 @@ object Driver {
}

// common arguments to all online commands
trait OnlineSubcommand extends AddGcpSubCommandArgs { s: ScallopConf =>
trait OnlineSubcommand extends SharedSubCommandArgs { s: ScallopConf =>
// this is `-Z` and not `-D` because sbt-pack plugin uses that for JAVA_OPTS
val propsInner: Map[String, String] = props[String]('Z')
val onlineJar: ScallopOption[String] =
@@ -615,7 +618,7 @@ object Driver {
}

def metaDataStore =
new MetadataStore(impl(serializableProps).genKvStore, MetadataDataset, timeoutMillis = 10000)
new MetadataStore(api.genKvStore, MetadataDataset, timeoutMillis = 10000)

def impl(props: Map[String, String]): Api = {
val urls = Array(new File(onlineJar()).toURI.toURL)
@@ -748,7 +751,7 @@ object Driver {

def run(args: Args): Unit = {
val acceptedEndPoints = List(MetadataEndPoint.ConfByKeyEndPointName, MetadataEndPoint.NameByTeamEndPointName)
val dirWalker = new MetadataDirWalker(args.confPath(), acceptedEndPoints)
val dirWalker = new MetadataDirWalker(args.confPath(), acceptedEndPoints, maybeConfType = args.confType.toOption)
val kvMap: Map[String, Map[String, List[String]]] = dirWalker.run

// trigger creates of the datasets before we proceed with writes