diff --git a/api/py/ai/chronon/repo/run.py b/api/py/ai/chronon/repo/run.py
index c066701c3e..cce4b17499 100755
--- a/api/py/ai/chronon/repo/run.py
+++ b/api/py/ai/chronon/repo/run.py
@@ -64,8 +64,7 @@
 # Constants for supporting multiple spark versions.
 SUPPORTED_SPARK = ["2.4.0", "3.1.1", "3.2.1", "3.5.1"]
-SCALA_VERSION_FOR_SPARK = {"2.4.0": "2.11",
-                           "3.1.1": "2.12", "3.2.1": "2.13", "3.5.1": "2.12"}
+SCALA_VERSION_FOR_SPARK = {"2.4.0": "2.11", "3.1.1": "2.12", "3.2.1": "2.13", "3.5.1": "2.12"}
 
 MODE_ARGS = {
     "backfill": OFFLINE_ARGS,
@@ -407,9 +406,11 @@ def __init__(self, args, jar_path):
         self.sub_help = args["sub_help"]
         self.mode = args["mode"]
         self.online_jar = args["online_jar"]
+        self.dataproc = args["dataproc"]
         valid_jar = args["online_jar"] and os.path.exists(args["online_jar"])
+
         # fetch online jar if necessary
-        if (self.mode in ONLINE_MODES) and (not args["sub_help"]) and not valid_jar:
+        if (self.mode in ONLINE_MODES) and (not args["sub_help"] and not self.dataproc) and not valid_jar:
             print("Downloading online_jar")
             self.online_jar = check_output("{}".format(args["online_jar_fetch"])).decode(
                 "utf-8"
@@ -434,8 +435,7 @@ def __init__(self, args, jar_path):
                 args["mode"] in possible_modes), ("Invalid mode:{} for conf:{} of type:{}, please choose from {}"
                 .format(args["mode"], self.conf, self.conf_type, possible_modes
                 ))
-        else:
-            self.conf_type = args["conf_type"]
+
         self.ds = args["end_ds"] if "end_ds" in args and args["end_ds"] else args["ds"]
         self.start_ds = (
             args["start_ds"] if "start_ds" in args and args["start_ds"] else None
@@ -459,7 +459,6 @@ def __init__(self, args, jar_path):
         else:
             self.spark_submit = args["spark_submit_path"]
         self.list_apps_cmd = args["list_apps"]
-        self.dataproc = args["dataproc"]
 
     def run(self):
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -470,7 +469,7 @@ def run(self):
                         script=self.render_info, conf=self.conf, ds=self.ds, repo=self.repo
                     )
                 )
-            elif self.sub_help or (self.mode not in SPARK_MODES):
+            elif (self.sub_help or (self.mode not in SPARK_MODES)) and not self.dataproc:
                 command_list.append(
                     "java -cp {jar} ai.chronon.spark.Driver {subcommand} {args}".format(
                         jar=self.jar_path,
@@ -660,6 +659,8 @@ def _gen_final_args(self, start_ds=None, end_ds=None, override_conf_path=None):
             online_jar=self.online_jar,
             online_class=self.online_class,
         )
+        base_args = base_args + f" --conf-type={self.conf_type} " if self.conf_type else base_args
+
         override_start_partition_arg = (
             " --start-partition-override=" + start_ds if start_ds else ""
         )
@@ -867,7 +868,7 @@ def main(ctx, conf, env, mode, dataproc, ds, app_name, start_ds, end_ds, paralle
     set_runtime_env(ctx.params)
     set_defaults(ctx)
     jar_type = "embedded" if mode in MODES_USING_EMBEDDED else "uber"
-    extra_args = (" " + online_args) if mode in ONLINE_MODES else ""
+    extra_args = (" " + online_args) if mode in ONLINE_MODES and online_args else ""
     ctx.params["args"] = " ".join(unknown_args) + extra_args
     jar_path = (
         chronon_jar
diff --git a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreImpl.scala b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreImpl.scala
index 96e728269d..26f3c8b701 100644
--- a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreImpl.scala
+++ b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/BigTableKVStoreImpl.scala
@@ -96,7 +96,8 @@ class BigTableKVStoreImpl(dataClient: BigtableDataClient,
       // we can explore split points if we need custom tablet partitioning. For now though, we leave this to BT
       val createTableRequest = CreateTableRequest.of(dataset).addFamily(ColumnFamilyString, DefaultGcRules)
       val table = adminClient.createTable(createTableRequest)
-
+      // TODO: this actually submits an async task. thus, the submission can succeed but the task can fail.
+      // doesn't return a future but maybe we can poll
       logger.info(s"Created table: $table")
 
       metricsContext.increment("create.successes")
diff --git a/online/src/main/scala/ai/chronon/online/MetadataDirWalker.scala b/online/src/main/scala/ai/chronon/online/MetadataDirWalker.scala
index f712b55231..3323f91639 100644
--- a/online/src/main/scala/ai/chronon/online/MetadataDirWalker.scala
+++ b/online/src/main/scala/ai/chronon/online/MetadataDirWalker.scala
@@ -15,7 +15,12 @@ import java.nio.file.Paths
 import scala.reflect.ClassTag
 import scala.util.Try
 
-class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String]) {
+class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String], maybeConfType: Option[String] = None) {
+
+  val JoinKeyword = "joins"
+  val GroupByKeyword = "group_bys"
+  val StagingQueryKeyword = "staging_queries"
+  val ModelKeyword = "models"
 
   @transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)
   private def loadJsonToConf[T <: TBase[_, _]: Manifest: ClassTag](file: String): Option[T] = {
@@ -77,10 +82,14 @@ class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String]) {
       val optConf =
         try {
           filePath match {
-            case value if value.contains("joins/")           => loadJsonToConf[api.Join](filePath)
-            case value if value.contains("group_bys/")       => loadJsonToConf[api.GroupBy](filePath)
-            case value if value.contains("staging_queries/") => loadJsonToConf[api.StagingQuery](filePath)
-            case value if value.contains("models/")          => loadJsonToConf[api.Model](filePath)
+            case value if value.contains(s"$JoinKeyword/") || maybeConfType.contains(JoinKeyword) =>
+              loadJsonToConf[api.Join](filePath)
+            case value if value.contains(s"$GroupByKeyword/") || maybeConfType.contains(GroupByKeyword) =>
+              loadJsonToConf[api.GroupBy](filePath)
+            case value if value.contains(s"$StagingQueryKeyword/") || maybeConfType.contains(StagingQueryKeyword) =>
+              loadJsonToConf[api.StagingQuery](filePath)
+            case value if value.contains(s"$ModelKeyword/") || maybeConfType.contains(ModelKeyword) =>
+              loadJsonToConf[api.Model](filePath)
           }
         } catch {
           case e: Throwable =>
@@ -94,22 +103,22 @@ class MetadataDirWalker(dirPath: String, metadataEndPointNames: List[String]) {
         val conf = optConf.get
         val kVPair = filePath match {
-          case value if value.contains("joins/") =>
+          case value if value.contains(s"$JoinKeyword/") || maybeConfType.contains(JoinKeyword) =>
             MetadataEndPoint
               .getEndPoint[api.Join](endPointName)
               .extractFn(filePath, conf.asInstanceOf[api.Join])
-          case value if value.contains("group_bys/") =>
+          case value if value.contains(s"$GroupByKeyword/") || maybeConfType.contains(GroupByKeyword) =>
             MetadataEndPoint
               .getEndPoint[api.GroupBy](endPointName)
              .extractFn(filePath, conf.asInstanceOf[api.GroupBy])
-          case value if value.contains("staging_queries/") =>
+          case value if value.contains(s"$StagingQueryKeyword/") || maybeConfType.contains(StagingQueryKeyword) =>
             MetadataEndPoint
               .getEndPoint[api.StagingQuery](endPointName)
               .extractFn(filePath, conf.asInstanceOf[api.StagingQuery])
-          case value if value.contains("models/") =>
+          case value if value.contains(s"$ModelKeyword/") || maybeConfType.contains(ModelKeyword) =>
             MetadataEndPoint
               .getEndPoint[api.Model](endPointName)
               .extractFn(filePath, conf.asInstanceOf[api.Model])
diff --git a/online/src/main/scala/ai/chronon/online/MetadataStore.scala b/online/src/main/scala/ai/chronon/online/MetadataStore.scala
index c072a77ff0..f98e04a41a 100644
--- a/online/src/main/scala/ai/chronon/online/MetadataStore.scala
+++ b/online/src/main/scala/ai/chronon/online/MetadataStore.scala
@@ -232,7 +232,10 @@ class MetadataStore(kvStore: KVStore, val dataset: String = MetadataDataset, tim
   def create(dataset: String): Unit = {
     try {
       logger.info(s"Creating dataset: $dataset")
+      // TODO: this is actually just an async task. it doesn't block and thus we don't actually
+      // know if it successfully created the dataset
       kvStore.create(dataset)
+      logger.info(s"Successfully created dataset: $dataset")
     } catch {
       case e: Exception =>
diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala
index 73d5680f08..471e330a50 100644
--- a/spark/src/main/scala/ai/chronon/spark/Driver.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala
@@ -86,7 +86,7 @@ object Driver {
   def parseConf[T <: TBase[_, _]: Manifest: ClassTag](confPath: String): T =
     ThriftJsonCodec.fromJsonFile[T](confPath, check = true)
 
-  trait AddGcpSubCommandArgs {
+  trait SharedSubCommandArgs {
     this: ScallopConf =>
     val isGcp: ScallopOption[Boolean] =
       opt[Boolean](required = false, default = Some(false), descr = "Whether to use GCP")
@@ -94,9 +94,12 @@ object Driver {
       opt[String](required = false, descr = "GCP project id")
     val gcpBigtableInstanceId: ScallopOption[String] =
       opt[String](required = false, descr = "GCP BigTable instance id")
+
+    val confType: ScallopOption[String] =
+      opt[String](required = false, descr = "Type of the conf to run. ex: join, group-by, etc")
   }
 
-  trait OfflineSubcommand extends AddGcpSubCommandArgs {
+  trait OfflineSubcommand extends SharedSubCommandArgs {
     this: ScallopConf =>
     val confPath: ScallopOption[String] = opt[String](required = true, descr = "Path to conf")
 
@@ -584,7 +587,7 @@ object Driver {
   }
 
   // common arguments to all online commands
-  trait OnlineSubcommand extends AddGcpSubCommandArgs { s: ScallopConf =>
+  trait OnlineSubcommand extends SharedSubCommandArgs { s: ScallopConf =>
     // this is `-Z` and not `-D` because sbt-pack plugin uses that for JAVA_OPTS
     val propsInner: Map[String, String] = props[String]('Z')
     val onlineJar: ScallopOption[String] =
@@ -615,7 +618,7 @@ object Driver {
     }
 
     def metaDataStore =
-      new MetadataStore(impl(serializableProps).genKvStore, MetadataDataset, timeoutMillis = 10000)
+      new MetadataStore(api.genKvStore, MetadataDataset, timeoutMillis = 10000)
 
     def impl(props: Map[String, String]): Api = {
       val urls = Array(new File(onlineJar()).toURI.toURL)
@@ -748,7 +751,7 @@ object Driver {
 
     def run(args: Args): Unit = {
       val acceptedEndPoints = List(MetadataEndPoint.ConfByKeyEndPointName, MetadataEndPoint.NameByTeamEndPointName)
-      val dirWalker = new MetadataDirWalker(args.confPath(), acceptedEndPoints)
+      val dirWalker = new MetadataDirWalker(args.confPath(), acceptedEndPoints, maybeConfType = args.confType.toOption)
       val kvMap: Map[String, Map[String, List[String]]] = dirWalker.run
 
       // trigger creates of the datasets before we proceed with writes
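
Usage sketch (not part of the diff): the new `maybeConfType` parameter lets `MetadataDirWalker` classify a conf even when its path does not contain the conventional `joins/`, `group_bys/`, `staging_queries/`, or `models/` segment. The snippet below is a minimal, hypothetical driver under stated assumptions: it assumes `MetadataDirWalker` and `MetadataEndPoint` live in the `ai.chronon.online` package implied by the file paths above, and the directory path and wrapping object are illustrative. Note that the forced type must equal one of the keyword constants (e.g. `"group_bys"`), since the walker matches with `maybeConfType.contains(GroupByKeyword)`.

```scala
import ai.chronon.online.{MetadataDirWalker, MetadataEndPoint}

// Hypothetical standalone example; everything except the MetadataDirWalker /
// MetadataEndPoint APIs shown in the diff is illustrative.
object ConfTypeWalkExample {
  def main(cliArgs: Array[String]): Unit = {
    // Same endpoints Driver.scala passes when uploading metadata.
    val acceptedEndPoints = List(MetadataEndPoint.ConfByKeyEndPointName, MetadataEndPoint.NameByTeamEndPointName)

    // A directory of compiled GroupBy confs that does not follow the "group_bys/"
    // path convention; maybeConfType forces them to be parsed as GroupBys.
    val dirWalker = new MetadataDirWalker(
      dirPath = "/tmp/compiled_confs",          // illustrative path
      metadataEndPointNames = acceptedEndPoints,
      maybeConfType = Some("group_bys")         // must match GroupByKeyword
    )

    // Same shape as the kvMap consumed in Driver.scala's metadata-upload path.
    val kvMap: Map[String, Map[String, List[String]]] = dirWalker.run
    kvMap.foreach { case (endPoint, entries) =>
      println(s"$endPoint: ${entries.size} entries")
    }
  }
}
```

On the CLI side, the same value flows end to end: `_gen_final_args` in `run.py` now appends `--conf-type=<type>`, `Driver.scala` picks it up through the new `confType` Scallop option, and the metadata-upload command forwards it to the walker via `args.confType.toOption`.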