diff --git a/.gitignore b/.gitignore index 9c7f4e3927..b0d1825439 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,8 @@ **/.DS_Store api/python/ai/chronon/api/ api/python/ai/chronon/observability/ +api/py/ai/chronon/api/ +api/py/ai/chronon/observability/ api/python/test/canary/production/ api/python/test/sample/production/ api/python/.coverage diff --git a/maven_install.json b/maven_install.json index f3f0c23952..83187e6b0e 100755 --- a/maven_install.json +++ b/maven_install.json @@ -1,7 +1,7 @@ { "__AUTOGENERATED_FILE_DO_NOT_MODIFY_THIS_FILE_MANUALLY": "THERE_IS_NO_DATA_ONLY_ZUUL", - "__INPUT_ARTIFACTS_HASH": -1078269646, - "__RESOLVED_ARTIFACTS_HASH": 254986144, + "__INPUT_ARTIFACTS_HASH": -1991826856, + "__RESOLVED_ARTIFACTS_HASH": 327110684, "artifacts": { "ant:ant": { "shasums": { @@ -587,10 +587,10 @@ }, "com.google.api.grpc:proto-google-cloud-pubsub-v1": { "shasums": { - "jar": "5cd9f8358c16577c735bcc478603c89a37b4c13e1bcf031262423fb99d79b509", - "sources": "406d9b9d9e70b7e407697c54463ba69afb304f43bf169d8e92d7876bcc8e8053" + "jar": "ec636b2e7b4908d8677e55326fddc228c6f9b1a4dd44ec5a4c193cf258887912", + "sources": "54c2c43a6d926eff4a27741323cce0ed7b6a7c402cf1a226f65edfcc897f1c4d" }, - "version": "1.113.0" + "version": "1.120.0" }, "com.google.api.grpc:proto-google-cloud-spanner-admin-database-v1": { "shasums": { @@ -629,10 +629,10 @@ }, "com.google.api.grpc:proto-google-common-protos": { "shasums": { - "jar": "61ac7fbd31a9f604890d22330a6f94b3f410ea2d7247e0f5f11a87ae34087385", - "sources": "736c912f7477663288f22e85fabe4c3c5fc05e9d4d0fd8362b94f62d59f9a377" + "jar": "2fcff25fe8a90fcacb146a900222c497ba0a9a531271e6b135a76450d23b1ef2", + "sources": "7d05a0c924f0101e5a4347bcc6b529b61af4a350c228aa9d1abe9f07e93bbdb7" }, - "version": "2.53.0" + "version": "2.54.1" }, "com.google.api.grpc:proto-google-iam-v1": { "shasums": { @@ -643,24 +643,24 @@ }, "com.google.api:api-common": { "shasums": { - "jar": 
"335933f1043d3b4022e301a7bba2a5614bbd59df88e6eb7e311d780669d55c20", - "sources": "48911d85a7145c42304c71ce9940f994a5987ef53317ac2bcae902b653e37f7b" + "jar": "8b11e1e1e42702cb80948e7ca62a9e06ddf82fe57a19cd68f9548eac80f39071", + "sources": "da573c313dbb0022602e9475d8077aeaf1dc603a3ae46569c0ee6e2d4f3e6d73" }, - "version": "2.45.0" + "version": "2.46.1" }, "com.google.api:gax": { "shasums": { - "jar": "0cc9de317cff3f67a260364dca1a72b720c940b525e533dd25a8b70e38b5f815", - "sources": "2c173d838ab5334d62554c866632a5057ff95f75ec60d6aebcfe4ee9cc6d2141" + "jar": "73a5d012fa89f8e589774ab51859602e0a6120b55eab049f903cb43f2d0feb74", + "sources": "ed55f66eb516c3608bb9863508a7299403a403755032295af987c93d72ae7297" }, - "version": "2.62.0" + "version": "2.60.0" }, "com.google.api:gax-grpc": { "shasums": { - "jar": "3a7f3a7966592fff66e2709cc8cec4c18be6ec073b43510d943bbbaf076b5e46", - "sources": "1861339fefc7591d8e7be47e2f5fca68477b7fe25680961d9b9c9781d71b7f4c" + "jar": "3ed87c6a43ad37c82e5e594c615e2f067606c45b977c97abfcfdd0bcc02ed852", + "sources": "790e0921e4b2f303e0003c177aa6ba11d3fe54ea33ae07c7b2f3bc8adec7d407" }, - "version": "2.62.0" + "version": "2.60.0" }, "com.google.api:gax-httpjson": { "shasums": { @@ -5440,12 +5440,16 @@ "org.checkerframework:checker-qual" ], "com.google.api.grpc:proto-google-cloud-pubsub-v1": [ + "com.google.api.grpc:proto-google-common-protos", + "com.google.api:api-common", "com.google.auto.value:auto-value-annotations", "com.google.code.findbugs:jsr305", "com.google.errorprone:error_prone_annotations", "com.google.guava:failureaccess", + "com.google.guava:guava", "com.google.guava:listenablefuture", "com.google.j2objc:j2objc-annotations", + "com.google.protobuf:protobuf-java", "javax.annotation:javax.annotation-api", "org.checkerframework:checker-qual" ], @@ -5522,7 +5526,6 @@ "com.google.auth:google-auth-library-oauth2-http", "com.google.guava:guava", "com.google.protobuf:protobuf-java", - "com.google.protobuf:protobuf-java-util", 
"io.opencensus:opencensus-api", "org.threeten:threetenbp" ], @@ -9670,7 +9673,6 @@ "com.google.api:gax": [ "com.google.api.gax.batching", "com.google.api.gax.core", - "com.google.api.gax.logging", "com.google.api.gax.longrunning", "com.google.api.gax.nativeimage", "com.google.api.gax.paging", diff --git a/orchestration/BUILD.bazel b/orchestration/BUILD.bazel index 5e0dda025b..a404a1bbf5 100644 --- a/orchestration/BUILD.bazel +++ b/orchestration/BUILD.bazel @@ -7,10 +7,10 @@ scala_library( }), visibility = ["//visibility:public"], deps = _VERTX_DEPS + [ - "//service_commons:lib", "//api:lib", "//api:thrift_java", "//online:lib", + "//service_commons:lib", maven_artifact_with_suffix("org.apache.logging.log4j:log4j-api-scala"), maven_artifact("org.apache.logging.log4j:log4j-core"), maven_artifact("org.apache.logging.log4j:log4j-api"), @@ -20,12 +20,18 @@ scala_library( maven_artifact("com.fasterxml.jackson.core:jackson-databind"), maven_artifact("com.google.protobuf:protobuf-java"), maven_artifact("com.google.code.findbugs:jsr305"), + maven_artifact("io.grpc:grpc-api"), maven_artifact("io.grpc:grpc-core"), maven_artifact("io.grpc:grpc-stub"), maven_artifact("io.grpc:grpc-inprocess"), maven_artifact("com.google.cloud:google-cloud-spanner"), + maven_artifact("com.google.cloud:google-cloud-pubsub"), + maven_artifact("com.google.api:gax"), + maven_artifact("com.google.api:gax-grpc"), + maven_artifact("com.google.api.grpc:proto-google-cloud-pubsub-v1"), maven_artifact("org.postgresql:postgresql"), maven_artifact_with_suffix("com.typesafe.slick:slick"), + maven_artifact("com.google.api:api-common"), ], ) @@ -44,6 +50,7 @@ test_deps = _VERTX_DEPS + _SCALA_TEST_DEPS + [ maven_artifact("io.temporal:temporal-serviceclient"), maven_artifact("com.fasterxml.jackson.core:jackson-core"), maven_artifact("com.fasterxml.jackson.core:jackson-databind"), + maven_artifact("io.grpc:grpc-api"), maven_artifact("io.grpc:grpc-core"), maven_artifact("io.grpc:grpc-stub"), 
maven_artifact("io.grpc:grpc-inprocess"), @@ -53,6 +60,11 @@ test_deps = _VERTX_DEPS + _SCALA_TEST_DEPS + [ maven_artifact("org.testcontainers:jdbc"), maven_artifact("org.testcontainers:testcontainers"), maven_artifact_with_suffix("com.typesafe.slick:slick"), + maven_artifact("com.google.cloud:google-cloud-pubsub"), + maven_artifact("com.google.api:gax"), + maven_artifact("com.google.api:gax-grpc"), + maven_artifact("com.google.api.grpc:proto-google-cloud-pubsub-v1"), + maven_artifact("com.google.api:api-common"), ] scala_library( @@ -73,6 +85,7 @@ scala_test_suite( # Excluding integration tests exclude = [ "src/test/**/NodeExecutionWorkflowIntegrationSpec.scala", + "src/test/**/GcpPubSubIntegrationSpec.scala", ], ), visibility = ["//visibility:public"], @@ -84,8 +97,12 @@ scala_test_suite( srcs = glob( [ "src/test/**/NodeExecutionWorkflowIntegrationSpec.scala", + "src/test/**/GcpPubSubIntegrationSpec.scala", ], ), + env = { + "PUBSUB_EMULATOR_HOST": "localhost:8085", + }, visibility = ["//visibility:public"], deps = test_deps + [":test_lib"], ) diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/persistence/NodeDao.scala b/orchestration/src/main/scala/ai/chronon/orchestration/persistence/NodeDao.scala index 15ee572e4c..ca81629a0e 100644 --- a/orchestration/src/main/scala/ai/chronon/orchestration/persistence/NodeDao.scala +++ b/orchestration/src/main/scala/ai/chronon/orchestration/persistence/NodeDao.scala @@ -50,9 +50,6 @@ class NodeRunDependencyTable(tag: Tag) extends Table[NodeRunDependency](tag, "No val parentRunId = column[String]("parent_run_id") val childRunId = column[String]("child_run_id") - // Composite primary key -// def pk = primaryKey("pk_node_run_dependency", (parentRunId, childRunId)) - def * = (parentRunId, childRunId).mapTo[NodeRunDependency] } @@ -63,9 +60,6 @@ class NodeRunAttemptTable(tag: Tag) extends Table[NodeRunAttempt](tag, "NodeRunA val endTime = column[Option[String]]("end_time") val status = column[String]("status") - // 
Composite primary key
-//  def pk = primaryKey("pk_node_run_attempt", (runId, attemptId))
-
   def * = (runId, attemptId, startTime, endTime, status).mapTo[NodeRunAttempt]
 }
 
diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubAdmin.scala b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubAdmin.scala
new file mode 100644
index 0000000000..c0a8a84968
--- /dev/null
+++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubAdmin.scala
@@ -0,0 +1,171 @@
+package ai.chronon.orchestration.pubsub
+
+import ai.chronon.orchestration.utils.GcpPubSubAdminUtils
+import com.google.cloud.pubsub.v1.{SubscriptionAdminClient, TopicAdminClient}
+import com.google.pubsub.v1.{PushConfig, SubscriptionName, TopicName}
+import org.slf4j.LoggerFactory
+
+/** Administrative interface for managing Pub/Sub topics and subscriptions.
+  *
+  * This trait defines operations for creating and managing Pub/Sub resources,
+  * providing a clean abstraction over the underlying Pub/Sub implementation.
+  * It handles:
+  *
+  * - Topic creation and management
+  * - Subscription creation and management
+  * - Resource cleanup
+  *
+  * The interface is designed to be implementation-agnostic, allowing for
+  * different Pub/Sub backends to be used interchangeably. The only
+  * implementation in this file is `GcpPubSubAdmin`; instances are normally
+  * obtained through the `PubSubAdmin` companion object.
+  */
+trait PubSubAdmin {
+
+  /** Creates a topic in the Pub/Sub system if it doesn't already exist.
+    *
+    * This method attempts to create the topic and:
+    * - Succeeds if the topic is created successfully
+    * - Ignores the error if the topic already exists (making it idempotent)
+    * - Throws an exception for any other error
+    *
+    * Only the bare topic ID is passed in; qualifying it (for example with a
+    * project) is left to the implementation.
+    *
+    * @param topicId The unique identifier for the topic
+    * @throws Exception If there's an error creating the topic (other than 'already exists')
+    */
+  def createTopic(topicId: String): Unit
+
+  /** Creates a subscription to a topic if it doesn't already exist.
+    *
+    * This method attempts to create the subscription and:
+    * - Succeeds if the subscription is created successfully
+    * - Ignores the error if the subscription already exists (making it idempotent)
+    * - Throws an exception for any other error
+    *
+    * NOTE(review): nothing in this contract creates the topic itself; callers
+    * presumably need the topic to exist first - confirm against the backend.
+    *
+    * @param topicId The topic ID to subscribe to
+    * @param subscriptionId The unique identifier for the subscription
+    * @throws Exception If there's an error creating the subscription (other than 'already exists')
+    */
+  def createSubscription(topicId: String, subscriptionId: String): Unit
+
+  /** Releases resources and closes all admin clients.
+    *
+    * This method should be called when the admin is no longer needed to
+    * properly release resources and avoid leaks. The signature declares no
+    * result, so callers cannot observe from the return value whether cleanup
+    * succeeded.
+    */
+  def close(): Unit
+}
+
+/** Google Cloud Pub/Sub implementation of the PubSubAdmin interface.
+  *
+  * This class provides Google Cloud-specific implementation of Pub/Sub administrative
+  * operations, using the Google Cloud Pub/Sub Admin API clients.
+  *
+  * This implementation uses a configuration that can be configured for either
+  * production use with real GCP credentials or local emulator use.
+  *
+  * Both admin clients are created lazily on first use. "Already exists" failures
+  * from create calls are treated as success so that creation is idempotent.
+  *
+  * @param config The Google Cloud Pub/Sub configuration to use
+  */
+class GcpPubSubAdmin(config: GcpPubSubConfig) extends PubSubAdmin {
+  private val logger = LoggerFactory.getLogger(getClass)
+  // Ack deadline (in seconds) applied to every subscription created by this admin.
+  private val ackDeadlineSeconds = 10
+  protected lazy val topicAdminClient: TopicAdminClient = GcpPubSubAdminUtils.createTopicAdminClient(config)
+  protected lazy val subscriptionAdminClient: SubscriptionAdminClient =
+    GcpPubSubAdminUtils.createSubscriptionAdminClient(config)
+
+  /** Returns true when `e` signals that the resource being created already exists.
+    *
+    * The typed gax exception is checked first; the message check is kept as a
+    * fallback for errors that only carry the status name in their text.
+    */
+  private def isAlreadyExists(e: Throwable): Boolean =
+    e.isInstanceOf[com.google.api.gax.rpc.AlreadyExistsException] ||
+      Option(e.getMessage).exists(_.contains("ALREADY_EXISTS"))
+
+  /** Runs one shutdown action, logging (not propagating) any failure.
+    *
+    * @return true if the action completed without throwing
+    */
+  private def shutdownQuietly(action: => Unit): Boolean =
+    try {
+      action
+      true
+    } catch {
+      case e: Exception =>
+        logger.error("Error shutting down PubSub admin clients", e)
+        false
+    }
+
+  /** Creates the topic `topicId` in the configured project, ignoring "already exists".
+    *
+    * @param topicId The unique identifier for the topic
+    * @throws Exception Any creation failure other than "already exists" is logged and rethrown
+    */
+  override def createTopic(topicId: String): Unit = {
+    val topicName = TopicName.of(config.projectId, topicId)
+
+    try {
+      topicAdminClient.createTopic(topicName)
+      logger.info(s"Created topic: ${topicName.toString}")
+    } catch {
+      case e: Exception if isAlreadyExists(e) =>
+        // An existing topic is the desired end state, so this is informational only
+        logger.info(s"Topic $topicId already exists, skipping creation")
+      case e: Exception =>
+        // Pass the exception itself so the stack trace is not lost
+        logger.error(s"Error creating topic ${topicName.toString}: ${e.getMessage}", e)
+        throw e
+    }
+  }
+
+  /** Creates a pull subscription `subscriptionId` on `topicId`, ignoring "already exists".
+    *
+    * @param topicId The topic ID to subscribe to
+    * @param subscriptionId The unique identifier for the subscription
+    * @throws Exception Any creation failure other than "already exists" is logged and rethrown
+    */
+  override def createSubscription(topicId: String, subscriptionId: String): Unit = {
+    val topicName = TopicName.of(config.projectId, topicId)
+    val subscriptionName = SubscriptionName.of(config.projectId, subscriptionId)
+
+    try {
+      subscriptionAdminClient.createSubscription(
+        subscriptionName,
+        topicName,
+        PushConfig.getDefaultInstance, // Pull subscription
+        ackDeadlineSeconds
+      )
+      logger.info(s"Created subscription: ${subscriptionName.toString}")
+    } catch {
+      case e: Exception if isAlreadyExists(e) =>
+        // An existing subscription is the desired end state, so this is informational only
+        logger.info(s"Subscription $subscriptionId already exists for $topicId, skipping creation")
+      case e: Exception =>
+        logger.error(s"Error creating subscription ${subscriptionName.toString}: ${e.getMessage}", e)
+        throw e
+    }
+  }
+
+  /** Shuts down both admin clients; failures are logged and never propagated.
+    *
+    * Each client is shut down in its own guarded step so that a failure on the
+    * topic client can no longer skip (and leak) the subscription client.
+    *
+    * NOTE: reading the lazy vals here instantiates a client that was never used
+    * before; that pre-existing behaviour is kept so subclasses overriding the
+    * clients still see `shutdown()` invoked on them.
+    */
+  override def close(): Unit = {
+    val topicClosed = shutdownQuietly {
+      if (topicAdminClient != null) {
+        topicAdminClient.shutdown()
+      }
+    }
+
+    val subscriptionClosed = shutdownQuietly {
+      if (subscriptionAdminClient != null) {
+        subscriptionAdminClient.shutdown()
+      }
+    }
+
+    if (topicClosed && subscriptionClosed) {
+      logger.info("PubSub admin clients shut down successfully")
+    }
+  }
+}
+
+/** Factory for creating PubSubAdmin instances.
+  *
+  * This object provides factory methods to create different types of PubSubAdmin
+  * implementations with appropriate configurations. It simplifies the creation of
+  * admin instances for both production and local development environments.
+  *
+  * Usage examples:
+  *
+  * For production
+  * val admin = PubSubAdmin(GcpPubSubConfig.forProduction("my-project"))
+  *
+  * For local emulator
+  * val emulatorAdmin = PubSubAdmin.forEmulator("test-project", "localhost:8085")
+  */
+object PubSubAdmin {
+
+  /** Creates a new Google Cloud PubSubAdmin instance with the provided configuration.
+    *
+    * @param config The Google Cloud Pub/Sub configuration
+    * @return A new PubSubAdmin instance
+    */
+  def apply(config: GcpPubSubConfig): PubSubAdmin = {
+    new GcpPubSubAdmin(config)
+  }
+
+  /** Creates a PubSubAdmin configured for use with the Pub/Sub emulator.
+    *
+    * This is a convenience method for local development and testing,
+    * as it automatically configures the necessary settings for emulator use.
+    *
+    * @param projectId The project ID to use with the emulator
+    * @param emulatorHost The emulator host:port (e.g., "localhost:8085")
+    * @return A PubSubAdmin configured for emulator use
+    */
+  def forEmulator(projectId: String, emulatorHost: String): PubSubAdmin = {
+    val config = GcpPubSubConfig.forEmulator(projectId, emulatorHost)
+    apply(config)
+  }
+}
diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubConfig.scala b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubConfig.scala
new file mode 100644
index 0000000000..7646994d24
--- /dev/null
+++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubConfig.scala
@@ -0,0 +1,109 @@
+package ai.chronon.orchestration.pubsub
+
+import com.google.api.gax.core.{CredentialsProvider, NoCredentialsProvider}
+import com.google.api.gax.grpc.GrpcTransportChannel
+import com.google.api.gax.rpc.{FixedTransportChannelProvider, TransportChannelProvider}
+import io.grpc.ManagedChannelBuilder
+
+/** Common contract shared by every Pub/Sub client configuration.
+  *
+  * A configuration only has to expose a string identifier. Components that hold
+  * on to clients (publishers, subscribers, admin clients) can use that identifier
+  * to:
+  * - key caches of already-built clients
+  * - tell configurations apart in logs and metrics
+  * - tie resource lifecycles to a particular configuration
+  */
+trait PubSubConfig {
+
+  /** Identifier of this configuration.
+    *
+    * Two configurations that should be treated as equivalent are expected to
+    * return the same value.
+    *
+    * @return A string that identifies this configuration
+    */
+  def id: String
+}
+
+/** Connection settings for Google Cloud Pub/Sub clients.
+  *
+  * Bundles what a client needs in order to connect:
+  * - the project to operate in
+  * - an optional transport channel provider
+  * - an optional credentials provider
+  *
+  * Leaving both providers unset yields the production setup (see
+  * `GcpPubSubConfig.forProduction`); supplying them is how the emulator setup
+  * (see `GcpPubSubConfig.forEmulator`) is built.
+  *
+  * @param projectId The Google Cloud project ID
+  * @param channelProvider Optional custom transport channel provider (for emulator or testing)
+  * @param credentialsProvider Optional custom credentials provider (for emulator or testing)
+  */
+case class GcpPubSubConfig(
+    projectId: String,
+    channelProvider: Option[TransportChannelProvider] = None,
+    credentialsProvider: Option[CredentialsProvider] = None
+) extends PubSubConfig {
+
+  /** Builds the identifier as `<projectId>-<channel hash>-<credentials hash>`.
+    *
+    * The two hashes are the `hashCode` values of the `Option` fields, so the
+    * identifier is only as stable as the providers' own `hashCode`.
+    * NOTE(review): providers presumably hash by identity, which would give every
+    * separately built emulator configuration a different id - confirm.
+    *
+    * @return A string that identifies this configuration
+    */
+  override def id: String = {
+    val channelHash = channelProvider.hashCode
+    val credentialsHash = credentialsProvider.hashCode
+    s"$projectId-$channelHash-$credentialsHash"
+  }
+}
+
+/** Companion object for GcpPubSubConfig with factory methods for common configurations.
+  *
+  * This object provides factory methods to create GcpPubSubConfig instances
+  * preconfigured for common scenarios:
+  * - Production use with default GCP credentials
+  * - Local development with the Pub/Sub emulator
+  */
+object GcpPubSubConfig {
+
+  /** Creates a configuration for production Google Cloud environments.
+    *
+    * This configuration uses the default GCP credentials available in the
+    * environment (e.g., from service account files, application default credentials, etc.).
+    * It's suitable for use in production Google Cloud environments like GCE, GKE, or Cloud Run.
+    *
+    * @param projectId The Google Cloud project ID
+    * @return A configuration for production use
+    */
+  def forProduction(projectId: String): GcpPubSubConfig =
+    GcpPubSubConfig(projectId)
+
+  /** Builds a configuration that targets a Pub/Sub emulator.
+    *
+    * A plaintext (non-TLS) gRPC channel to `emulatorHost` is opened and wrapped
+    * in a fixed channel provider, and a no-op credentials provider is attached,
+    * so no GCP resources or credentials are required.
+    *
+    * NOTE(review): every call builds a new channel, and nothing in this file
+    * shuts it down - presumably whoever uses the config owns it; confirm.
+    *
+    * @param projectId The project ID to use with the emulator (can be any string)
+    * @param emulatorHost The emulator host:port address (default: localhost:8085)
+    * @return Configuration optimized for emulator use
+    */
+  def forEmulator(projectId: String, emulatorHost: String = "localhost:8085"): GcpPubSubConfig = {
+    // Emulator traffic is plaintext, so TLS is switched off on the channel
+    val plaintextChannel = ManagedChannelBuilder.forTarget(emulatorHost).usePlaintext().build()
+    val fixedChannel = FixedTransportChannelProvider.create(GrpcTransportChannel.create(plaintextChannel))
+
+    // The emulator performs no authentication
+    val noCredentials = NoCredentialsProvider.create()
+
+    GcpPubSubConfig(projectId, Some(fixedChannel), Some(noCredentials))
+  }
+}
diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubManager.scala b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubManager.scala
new file mode 100644
index 0000000000..55000cc77b
--- /dev/null
+++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubManager.scala
@@ -0,0 +1,188 @@
+package ai.chronon.orchestration.pubsub
+
+import org.slf4j.LoggerFactory
+
+import scala.collection.concurrent.TrieMap
+
+/** Central manager for Pub/Sub publishers and subscribers.
+  *
+  * A manager is the single entry point for obtaining publishers and subscribers.
+  * It acts as both factory and registry for those components and is meant to
+  * provide:
+  *
+  * - Consistent configuration across components
+  * - Resource reuse (clients are cached and shared)
+  * - Proper resource lifecycle management
+  * - Automatic topic and subscription creation
+  *
+  * Application code asks the manager for a component and leaves the
+  * infrastructure details to it.
+  */
+trait PubSubManager {
+
+  /** Returns the publisher for a topic, creating it on first request.
+    *
+    * The topic is created if necessary before the publisher is handed out, and
+    * the publisher is cached so later calls for the same topic reuse it.
+    *
+    * @param topicId The unique identifier for the topic
+    * @return A publisher configured for the specified topic
+    */
+  def getOrCreatePublisher(topicId: String): PubSubPublisher
+
+  /** Returns the subscriber for a subscription, creating it on first request.
+    *
+    * The subscription is created if necessary before the subscriber is handed
+    * out, and the subscriber is cached so later calls for the same subscription
+    * reuse it.
+    *
+    * @param topicId The topic ID to subscribe to
+    * @param subscriptionId The unique identifier for the subscription
+    * @return A subscriber configured for the specified subscription
+    */
+  def getOrCreateSubscriber(topicId: String, subscriptionId: String): PubSubSubscriber
+
+  /** Releases all resources managed by this manager.
+    *
+    * This method should be called when the manager is no longer needed to
+    * properly close all publishers, subscribers, and other resources.
+    * It ensures clean shutdown of connections and prevents resource leaks.
+    */
+  def shutdown(): Unit
+}
+
+/** Google Cloud implementation of the PubSubManager interface.
+  *
+  * Publishers are cached by topic ID and subscribers by subscription ID in
+  * concurrent maps. When two threads race to create the same entry, exactly one
+  * instance is kept and returned to both; the other is shut down immediately
+  * instead of being leaked.
+  *
+  * @param config The Google Cloud Pub/Sub configuration to use
+  */
+class GcpPubSubManager(config: GcpPubSubConfig) extends PubSubManager {
+  private val logger = LoggerFactory.getLogger(getClass)
+  protected val admin: PubSubAdmin = PubSubAdmin(config)
+
+  // Cache of publishers by topic ID
+  private val publishers = TrieMap.empty[String, PubSubPublisher]
+
+  // Cache of subscribers by subscription ID
+  private val subscribers = TrieMap.empty[String, PubSubSubscriber]
+
+  /** Returns the cached publisher for `topicId`, creating topic and publisher if absent.
+    *
+    * `TrieMap.getOrElseUpdate` may evaluate its argument in several racing
+    * threads and silently drop all but one result, so the insert is done with
+    * `putIfAbsent` and the losing publisher is shut down explicitly.
+    *
+    * @param topicId The unique identifier for the topic
+    * @return The single cached publisher for the topic
+    */
+  override def getOrCreatePublisher(topicId: String): PubSubPublisher = {
+    publishers.getOrElse(
+      topicId, {
+        // Create the topic if it doesn't exist
+        admin.createTopic(topicId)
+
+        // Create a new publisher and try to register it
+        val created = PubSubPublisher(config, topicId)
+        publishers.putIfAbsent(topicId, created) match {
+          case Some(existing) =>
+            // Another thread registered a publisher first: release ours, reuse theirs
+            created.shutdown()
+            existing
+          case None => created
+        }
+      }
+    )
+  }
+
+  /** Returns the cached subscriber for `subscriptionId`, creating subscription and subscriber if absent.
+    *
+    * The cache is keyed by subscription ID only; `topicId` is used solely when
+    * the subscription has to be created.
+    *
+    * @param topicId The topic ID to subscribe to
+    * @param subscriptionId The unique identifier for the subscription
+    * @return The single cached subscriber for the subscription
+    */
+  override def getOrCreateSubscriber(topicId: String, subscriptionId: String): PubSubSubscriber = {
+    subscribers.getOrElse(
+      subscriptionId, {
+        // Create the subscription if it doesn't exist
+        admin.createSubscription(topicId, subscriptionId)
+
+        // Create a new subscriber and try to register it
+        val created = PubSubSubscriber(config, subscriptionId)
+        subscribers.putIfAbsent(subscriptionId, created) match {
+          case Some(existing) =>
+            // Another thread registered a subscriber first: release ours, reuse theirs
+            created.shutdown()
+            existing
+          case None => created
+        }
+      }
+    )
+  }
+
+  /** Shuts down every cached publisher and subscriber, then the admin client.
+    *
+    * Individual failures are logged and do not stop the remaining components
+    * from being shut down. The caches are emptied in a `finally` block so that
+    * already-shut-down clients are never handed out again, even if closing the
+    * admin client fails.
+    */
+  override def shutdown(): Unit = {
+    try {
+      // Shutdown all publishers
+      publishers.values.foreach { publisher =>
+        try {
+          publisher.shutdown()
+        } catch {
+          case e: Exception =>
+            logger.error(s"Error shutting down publisher: ${e.getMessage}", e)
+        }
+      }
+
+      // Shutdown all subscribers
+      subscribers.values.foreach { subscriber =>
+        try {
+          subscriber.shutdown()
+        } catch {
+          case e: Exception =>
+            logger.error(s"Error shutting down subscriber: ${e.getMessage}", e)
+        }
+      }
+
+      // Close the admin client
+      admin.close()
+
+      logger.info("PubSub manager shut down successfully")
+    } catch {
+      case e: Exception =>
+        logger.error("Error shutting down PubSub manager", e)
+    } finally {
+      // Clear the caches regardless of how the shutdown went
+      publishers.clear()
+      subscribers.clear()
+    }
+  }
+}
+
+/** Factory for creating and managing PubSubManager instances.
+  *
+  * This object provides factory methods for creating and retrieving PubSubManager
+  * instances. It uses a cache to ensure that only one manager instance is created
+  * for each unique configuration, promoting efficient resource usage.
+  */
+object PubSubManager {
+  // Thread-safe cache of managers by configuration ID
+  private val managers = TrieMap.empty[String, PubSubManager]
+
+  /** Gets or creates a Google Cloud PubSubManager for a specific configuration.
+    *
+    * @param config The Google Cloud Pub/Sub configuration
+    * @return A PubSubManager instance for the given configuration
+    */
+  def apply(config: GcpPubSubConfig): PubSubManager = {
+    managers.getOrElseUpdate(config.id, new GcpPubSubManager(config))
+  }
+
+  /** Creates a manager configured for production Google Cloud environments.
+    *
+    * This convenience method creates a manager with production settings,
+    * using default Google Cloud credentials in the specified project.
+    *
+    * @param projectId The Google Cloud project ID
+    * @return A manager configured for production use
+    */
+  def forProduction(projectId: String): PubSubManager = {
+    val config = GcpPubSubConfig.forProduction(projectId)
+    apply(config)
+  }
+
+  /** Creates a manager configured for the local Pub/Sub emulator.
+    *
+    * This convenience method creates a manager for development and testing
+    * with a local Pub/Sub emulator, automatically configuring the connection
+    * properties.
+    *
+    * @param projectId The project ID to use with the emulator
+    * @param emulatorHost The emulator host:port address
+    * @return A manager configured for emulator use
+    */
+  def forEmulator(projectId: String, emulatorHost: String): PubSubManager = {
+    val config = GcpPubSubConfig.forEmulator(projectId, emulatorHost)
+    apply(config)
+  }
+
+  /** Shuts down all manager instances and clears the cache.
+    *
+    * Intended as a one-call cleanup hook for application exit: every manager
+    * currently held in the cache is shut down, after which the cache itself is
+    * emptied.
+    */
+  def shutdownAll(): Unit = {
+    for (manager <- managers.values) {
+      manager.shutdown()
+    }
+    managers.clear()
+  }
+}
diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubMessage.scala b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubMessage.scala
new file mode 100644
index 0000000000..0e7a278400
--- /dev/null
+++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubMessage.scala
@@ -0,0 +1,147 @@
+package ai.chronon.orchestration.pubsub
+
+import ai.chronon.orchestration.DummyNode
+import com.google.protobuf.ByteString
+import com.google.pubsub.v1.PubsubMessage
+
+/** Base message interface for Pub/Sub messages.
+  *
+  * This trait defines a common interface for messages across different Pub/Sub
+  * implementations, providing a platform-agnostic way to work with message
+  * data and attributes.
+  *
+  * The interface separates:
+  * - Message metadata (attributes/properties)
+  * - Message payload (data/body)
+  *
+  * This abstraction allows the system to:
+  * - Work with different message formats consistently
+  * - Hide implementation-specific details from business logic
+  * - Support multiple Pub/Sub providers with a unified interface
+  * - Test message handling without depending on actual Pub/Sub implementations
+  */
+trait PubSubMessage {
+
+  /** Gets the message attributes/properties as key-value pairs.
+    *
+    * Message attributes are metadata associated with the message that provide
+    * additional context or routing information.
+    *
+    * @return A map of attribute names to their string values
+    */
+  def getAttributes: Map[String, String]
+
+  /** Gets the message payload/body data.
+    *
+    * The message data contains the actual content being delivered.
It may be
+    * empty for messages that only use attributes for signaling.
+    *
+    * @return The binary message data, if present
+    */
+  def getData: Option[Array[Byte]]
+}
+
+/** Google Cloud-specific extension of the PubSubMessage interface.
+  */
+trait GcpPubSubMessage extends PubSubMessage {
+
+  /** Converts this message to the Google Cloud Pub/Sub native format.
+    *
+    * This method transforms the abstract message representation into the
+    * concrete Google Pub/Sub format required by the Google Cloud libraries.
+    *
+    * @return A Google Cloud PubsubMessage ready for publishing
+    */
+  def toPubsubMessage: PubsubMessage
+}
+
+/** Message implementation for job submission requests.
+  *
+  * This class represents a message that triggers the execution of a node
+  * in the computation graph. It implements the Google Cloud Pub/Sub message
+  * interface, allowing it to be published to Google Cloud Pub/Sub topics.
+  *
+  * The `nodeName` field is authoritative: if `attributes` also contains a
+  * `"nodeName"` key, the field value wins both in `getAttributes` and in the
+  * message produced by `toPubsubMessage`.
+  *
+  * @param nodeName The name of the node to execute
+  * @param data Optional message body as a string
+  * @param attributes Additional key-value metadata for the message
+  */
+case class JobSubmissionMessage(
+    nodeName: String,
+    data: Option[String] = None,
+    attributes: Map[String, String] = Map.empty
+) extends GcpPubSubMessage {
+
+  /** Gets the combined message attributes including the node name.
+    *
+    * The `"nodeName"` entry is added last, so it replaces any attribute of the
+    * same name supplied through the constructor.
+    *
+    * @return Map containing the node name and any additional attributes
+    */
+  override def getAttributes: Map[String, String] = {
+    attributes + ("nodeName" -> nodeName)
+  }
+
+  /** Gets the message data as a UTF-8 encoded byte array.
+    *
+    * @return The message data as bytes, if present
+    */
+  override def getData: Option[Array[Byte]] = {
+    // Use the charset constant rather than a by-name lookup of "UTF-8"
+    data.map(_.getBytes(java.nio.charset.StandardCharsets.UTF_8))
+  }
+
+  /** Converts this message to a native Google Cloud Pub/Sub message.
+    *
+    * The attributes written are exactly those returned by `getAttributes`, so
+    * the published message and the abstract view can never disagree (previously
+    * a custom `"nodeName"` attribute overrode the node name here but not in
+    * `getAttributes`). The data, when present, is encoded as UTF-8.
+    *
+    * @return A properly formatted Google Cloud PubsubMessage
+    */
+  override def toPubsubMessage: PubsubMessage = {
+    val builder = PubsubMessage.newBuilder()
+
+    // Single source of truth for attributes, including the mandatory nodeName
+    getAttributes.foreach { case (key, value) =>
+      builder.putAttributes(key, value)
+    }
+
+    // Add message data if provided
+    data.foreach { d =>
+      builder.setData(ByteString.copyFromUtf8(d))
+    }
+
+    builder.build()
+  }
+}
+
+/** Factory methods for creating JobSubmissionMessage instances.
+  *
+  * This companion object provides convenient factory methods for creating
+  * JobSubmissionMessage instances from different sources.
+  *
+  * TODO: The conversion from DummyNode is temporary and will be replaced
+  * with more appropriate conversion methods in the future.
+  */
+object JobSubmissionMessage {
+
+  /** Creates a JobSubmissionMessage from a DummyNode.
+    *
+    * This is a temporary method for backward compatibility.
+    *
+    * The message carries the node's name as `nodeName` and a fixed
+    * human-readable body that mentions that name; no extra attributes are set.
+    *
+    * @param node The DummyNode to create a message for
+    * @return A JobSubmissionMessage for the node
+    * @deprecated Temporary helper. No `fromNodeName` replacement exists in this
+    *             object; construct `JobSubmissionMessage(nodeName = ...)` directly.
+    */
+  def fromDummyNode(node: DummyNode): JobSubmissionMessage = {
+    val payload = s"Job submission for node: ${node.name}"
+    JobSubmissionMessage(nodeName = node.name, data = Some(payload))
+  }
+}
diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubPublisher.scala b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubPublisher.scala
new file mode 100644
index 0000000000..a6245e9d6e
--- /dev/null
+++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubPublisher.scala
@@ -0,0 +1,176 @@
+package ai.chronon.orchestration.pubsub
+
+import com.google.api.core.{ApiFutureCallback, ApiFutures}
+import com.google.cloud.pubsub.v1.Publisher
+import com.google.pubsub.v1.TopicName
+import org.slf4j.LoggerFactory
+
+import java.util.concurrent.{CompletableFuture, Executors}
+import scala.util.{Failure, Success, Try}
+
+/** Publisher interface for sending messages to a Pub/Sub system.
+  *
+  * This trait defines the core functionality for publishing messages to
+  * a Pub/Sub topic, providing a clean abstraction over the underlying
+  * Pub/Sub implementation.
+  */
+trait PubSubPublisher {
+
+  /** Gets the topic ID this publisher publishes to.
+    *
+    * The topic ID uniquely identifies the destination for messages
+    * published through this publisher.
+    *
+    * @return The unique identifier for the topic
+    */
+  def topicId: String
+
+  /** Publishes a message to the Pub/Sub topic.
+    *
+    * This method asynchronously publishes a message to the topic and returns
+    * a future that completes when the message is successfully published or
+    * fails if there's an error.
+ * + * @param message The message to publish + * @return A CompletableFuture that completes with the message ID when published successfully, + * or completes exceptionally if there's an error + */ + def publish(message: PubSubMessage): CompletableFuture[String] + + /** Releases resources and shuts down the publisher. + * + * This method should be called when the publisher is no longer needed to + * properly release resources and avoid leaks. + */ + def shutdown(): Unit +} + +/** Google Cloud implementation of the PubSubPublisher interface. + * + * @param config The Google Cloud Pub/Sub configuration + * @param topicId The ID of the topic to publish to + */ +class GcpPubSubPublisher( + val config: GcpPubSubConfig, + val topicId: String +) extends PubSubPublisher { + private val logger = LoggerFactory.getLogger(getClass) + private val executor = Executors.newSingleThreadExecutor() + private lazy val publisher = createPublisher() + + protected def createPublisher(): Publisher = { + val topicName = TopicName.of(config.projectId, topicId) + logger.info(s"Creating publisher for topic: $topicName") + + // Start with the basic builder + val builder = Publisher.newBuilder(topicName) + + // Add channel provider if specified + config.channelProvider.foreach { provider => + logger.info("Using custom channel provider for Publisher") + builder.setChannelProvider(provider) + } + + // Add credentials provider if specified + config.credentialsProvider.foreach { provider => + logger.info("Using custom credentials provider for Publisher") + builder.setCredentialsProvider(provider) + } + + // Build the publisher + builder.build() + } + + override def publish(message: PubSubMessage): CompletableFuture[String] = { + val result = new CompletableFuture[String]() + + message match { + case gcpMessage: GcpPubSubMessage => + Try { + // Convert to Google PubSub message format + val pubsubMessage = gcpMessage.toPubsubMessage + + // Publish the message + val messageIdFuture = 
publisher.publish(pubsubMessage) + + // Add a callback to handle success/failure + ApiFutures.addCallback( + messageIdFuture, + new ApiFutureCallback[String] { + override def onFailure(t: Throwable): Unit = { + logger.error(s"Failed to publish message to $topicId", t) + result.completeExceptionally(t) + } + + override def onSuccess(messageId: String): Unit = { + logger.info(s"Published message with ID: $messageId to $topicId") + result.complete(messageId) + } + }, + executor + ) + } match { + case Success(_) => // Callback will handle completion + case Failure(e) => + logger.error(s"Error setting up message publishing to $topicId", e) + result.completeExceptionally(e) + } + case _ => + val error = new IllegalArgumentException( + s"Message type ${message.getClass.getName} is not supported for GCP PubSub. Expected GcpPubSubMessage.") + logger.error(error.getMessage) + result.completeExceptionally(error) + } + + result + } + + override def shutdown(): Unit = { + Try { + if (publisher != null) { + publisher.shutdown() + } + + executor.shutdown() + + logger.info(s"Publisher for topic $topicId shut down successfully") + } match { + case Success(_) => // Shutdown successful + case Failure(e) => logger.error(s"Error shutting down publisher for topic $topicId", e) + } + } +} + +/** Factory for creating PubSubPublisher instances. + */ +object PubSubPublisher { + + /** Creates a publisher for Google Cloud PubSub. + * + * This factory method creates a new PubSubPublisher instance configured + * for the Google Cloud Pub/Sub service using the provided configuration. + * + * @param config The Google Cloud Pub/Sub configuration + * @param topicId The ID of the topic to publish to + * @return A configured PubSubPublisher instance + */ + def apply(config: GcpPubSubConfig, topicId: String): PubSubPublisher = { + new GcpPubSubPublisher(config, topicId) + } + + /** Creates a publisher for the Pub/Sub emulator. 
+ * + * This convenience method creates a publisher configured to work with + * a local Pub/Sub emulator, which is useful for development and testing + * without requiring actual Google Cloud resources. + * + * @param projectId The Google Cloud project ID (can be any value for emulator) + * @param topicId The ID of the topic to publish to + * @param emulatorHost The host and port of the Pub/Sub emulator (e.g., "localhost:8085") + * @return A configured PubSubPublisher instance for the emulator + */ + def forEmulator(projectId: String, topicId: String, emulatorHost: String): PubSubPublisher = { + val config = GcpPubSubConfig.forEmulator(projectId, emulatorHost) + apply(config, topicId) + } +} diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubSubscriber.scala b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubSubscriber.scala new file mode 100644 index 0000000000..cc9a689942 --- /dev/null +++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/PubSubSubscriber.scala @@ -0,0 +1,149 @@ +package ai.chronon.orchestration.pubsub + +import ai.chronon.api.ScalaJavaConversions._ +import ai.chronon.orchestration.utils.GcpPubSubAdminUtils +import com.google.cloud.pubsub.v1.SubscriptionAdminClient +import com.google.pubsub.v1.{PubsubMessage, SubscriptionName} +import org.slf4j.LoggerFactory + +/** Generic subscriber interface for receiving messages from a Pub/Sub system. + * + * This trait defines the core functionality for subscribing to and receiving + * messages from a Pub/Sub system, providing a clean abstraction over the + * underlying Pub/Sub implementation. + */ +trait PubSubSubscriber { + private val batchSize = 10 + + /** Gets the subscription ID this subscriber listens to. + */ + def subscriptionId: String + + /** Pulls messages from the subscription. + * + * This method pulls a batch of messages from the subscription and automatically + * acknowledges them upon successful retrieval. 
The default batch size is 10 + * messages, but this can be configured through the maxMessages parameter. + * + * @param maxMessages Maximum number of messages to pull in a single batch + * @return A sequence of received messages + * @throws RuntimeException if there's an error communicating with the subscription + */ + def pullMessages(maxMessages: Int = batchSize): Seq[PubSubMessage] + + /** Releases resources and shuts down the subscriber. + * + * This method should be called when the subscriber is no longer needed to + * properly release resources and avoid leaks. + */ + def shutdown(): Unit +} + +/** Google Cloud implementation of the PubSubSubscriber interface. + * + * @param config The Google Cloud Pub/Sub configuration + * @param subscriptionId The ID of the subscription to pull messages from + */ +class GcpPubSubSubscriber( + config: GcpPubSubConfig, + val subscriptionId: String +) extends PubSubSubscriber { + private val logger = LoggerFactory.getLogger(getClass) + protected val adminClient: SubscriptionAdminClient = GcpPubSubAdminUtils.createSubscriptionAdminClient(config) + + override def pullMessages(maxMessages: Int): Seq[PubSubMessage] = { + val subscriptionName = SubscriptionName.of(config.projectId, subscriptionId) + + try { + val response = adminClient.pull(subscriptionName, maxMessages) + + val receivedMessages = response.getReceivedMessagesList.toScala + + // Convert to GCP-specific messages + val messages = receivedMessages + .map(received => { + val pubsubMessage = received.getMessage + + // Convert to our abstraction with special wrapper for GCP messages + new GcpPubSubMessageWrapper(pubsubMessage) + }) + + // Acknowledge the messages + if (messages.nonEmpty) { + try { + val ackIds = receivedMessages + .map(received => received.getAckId) + + adminClient.acknowledge(subscriptionName, ackIds.toJava) + } catch { + case e: Exception => + // Log the acknowledgment error but still return the messages + logger.warn(s"Error acknowledging messages from 
$subscriptionId: ${e.getMessage}") + } + } + + messages + } catch { + // TODO: To add proper error handling based on other potential scenarios + case e: Exception => + val errorMsg = s"Error pulling messages from $subscriptionId: ${e.getMessage}" + logger.error(errorMsg) + throw new RuntimeException(errorMsg, e) + } + } + + /** Releases resources and shuts down the subscriber. + * + * This method closes the Google Cloud Subscription Admin Client and + * releases all associated resources to prevent leaks. + */ + override def shutdown(): Unit = { + // Close the admin client + if (adminClient != null) { + adminClient.close() + } + logger.info(s"Subscriber for subscription $subscriptionId shut down successfully") + } +} + +/** Wrapper for Google Cloud PubSub messages that implements our PubSubMessage abstraction. + * + * This class wraps a Google Cloud PubsubMessage to make it compatible with + * our generic PubSubMessage interface. + * + * @param message The Google Cloud PubsubMessage to wrap + */ +class GcpPubSubMessageWrapper(val message: PubsubMessage) extends GcpPubSubMessage { + + override def getAttributes: Map[String, String] = { + message.getAttributesMap.toScala.toMap + } + + override def getData: Option[Array[Byte]] = { + if (message.getData.isEmpty) None + else Some(message.getData.toByteArray) + } + + override def toPubsubMessage: PubsubMessage = message +} + +/** Factory for creating PubSubSubscriber instances. + */ +object PubSubSubscriber { + + /** Creates a subscriber for Google Cloud PubSub. + * + * This factory method creates a new PubSubSubscriber instance configured + * for the Google Cloud Pub/Sub service using the provided configuration. 
+ * + * @param config The Google Cloud Pub/Sub configuration + * @param subscriptionId The ID of the subscription to pull messages from + * @return A configured PubSubSubscriber instance + */ + def apply( + config: GcpPubSubConfig, + subscriptionId: String + ): PubSubSubscriber = { + new GcpPubSubSubscriber(config, subscriptionId) + } +} diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/README.md b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/README.md new file mode 100644 index 0000000000..198e127a02 --- /dev/null +++ b/orchestration/src/main/scala/ai/chronon/orchestration/pubsub/README.md @@ -0,0 +1,92 @@ +# Chronon PubSub Module + +This module provides a flexible, modular, and lightweight abstraction for working with Google Cloud Pub/Sub. + +## Components + +The PubSub module is organized into several components to separate concerns and promote flexibility: + +### 1. Messages (`PubSubMessage.scala`) + +- `PubSubMessage` - Base trait for all messages that can be published to PubSub +- `JobSubmissionMessage` - Implementation for job submission messages + +### 2. Configuration (`PubSubConfig.scala`) + +- `PubSubConfig` - Configuration for PubSub connections +- Helper methods for creating production and emulator configurations + +### 3. Admin (`PubSubAdmin.scala`) + +- `PubSubAdmin` - Interface for managing topics and subscriptions +- `GcpPubSubAdmin` - Implementation for Google Cloud Pub/Sub + +### 4. Publisher (`PubSubPublisher.scala`) + +- `PubSubPublisher` - Interface for publishing messages +- `GcpPubSubPublisher` - Implementation for Google Cloud Pub/Sub + +### 5. Subscriber (`PubSubSubscriber.scala`) + +- `PubSubSubscriber` - Interface for receiving messages +- `GcpPubSubSubscriber` - Implementation for Google Cloud Pub/Sub + +### 6. 
Manager (`PubSubManager.scala`) + +- `PubSubManager` - Manages PubSub components and provides caching +- Factory methods for creating configured managers + +## Usage Examples + +### Basic Production Usage + +```scala +// Create a manager for production +val manager = PubSubManager.forProduction("my-project-id") + +// Get a publisher +val publisher = manager.getOrCreatePublisher("my-topic") + +// Create and publish a message +val message = JobSubmissionMessage("my-node", Some("Job data")) +val future = publisher.publish(message) + +// Get a subscriber +val subscriber = manager.getOrCreateSubscriber("my-topic", "my-subscription") + +// Pull messages +val messages = subscriber.pullMessages(10) + +// Remember to shut down when done +manager.shutdown() +``` + +### Testing with Emulator + +```scala +// Create a manager for the emulator +val manager = PubSubManager.forEmulator("test-project", "localhost:8085") + +// Now use it the same way as in production +val publisher = manager.getOrCreatePublisher("test-topic") +val subscriber = manager.getOrCreateSubscriber("test-topic", "test-subscription") +``` + +### Integration with NodeExecutionActivity + +```scala +// Create a publisher for the activity +val publisher = PubSubManager.forProduction("my-project-id") + .getOrCreatePublisher("job-submissions") + +// Create the activity with the publisher +val activity = NodeExecutionActivityFactory.create(workflowClient, publisher) +``` + +## Benefits + +1. **Separation of Concerns** - Each component has a single responsibility +2. **Dependency Injection** - Easy to inject and mock for testing +3. **Caching** - Publishers and subscribers are cached for efficiency +4. **Resource Management** - Clean shutdown of all resources +5.
**Emulator Support** - Seamless support for local testing with emulator \ No newline at end of file diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivity.scala b/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivity.scala index 2853de9bb4..0eb6506307 100644 --- a/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivity.scala +++ b/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivity.scala @@ -1,8 +1,10 @@ package ai.chronon.orchestration.temporal.activity import ai.chronon.orchestration.DummyNode +import ai.chronon.orchestration.pubsub.{JobSubmissionMessage, PubSubPublisher} import ai.chronon.orchestration.temporal.workflow.WorkflowOperations import io.temporal.activity.{Activity, ActivityInterface, ActivityMethod} +import org.slf4j.LoggerFactory /** Defines helper activity methods that are needed for node execution workflow */ @@ -22,10 +24,14 @@ import io.temporal.activity.{Activity, ActivityInterface, ActivityMethod} /** Dependency injection through constructor is supported for activities but not for workflows * https://community.temporal.io/t/complex-workflow-dependencies/511 */ -class NodeExecutionActivityImpl(workflowOps: WorkflowOperations) extends NodeExecutionActivity { +class NodeExecutionActivityImpl( + workflowOps: WorkflowOperations, + pubSubPublisher: PubSubPublisher +) extends NodeExecutionActivity { - override def triggerDependency(dependency: DummyNode): Unit = { + private val logger = LoggerFactory.getLogger(getClass) + override def triggerDependency(dependency: DummyNode): Unit = { val context = Activity.getExecutionContext context.doNotCompleteOnReturn() @@ -46,6 +52,27 @@ class NodeExecutionActivityImpl(workflowOps: WorkflowOperations) extends NodeExe } override def submitJob(node: DummyNode): Unit = { - // TODO: Actual Implementation for job submission + 
logger.info(s"Submitting job for node: ${node.name}") + + val context = Activity.getExecutionContext + context.doNotCompleteOnReturn() + + val completionClient = context.useLocalManualCompletion() + + // Create a message from the node + val message = JobSubmissionMessage.fromDummyNode(node) + + // Publish the message + val future = pubSubPublisher.publish(message) + + future.whenComplete((messageId, error) => { + if (error != null) { + logger.error(s"Failed to submit job for node: ${node.name}", error) + completionClient.fail(error) + } else { + logger.info(s"Successfully submitted job for node: ${node.name} with messageId: $messageId") + completionClient.complete(messageId) + } + }) } } diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivityFactory.scala b/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivityFactory.scala index c5dc0f0e5e..f58de1ee7c 100644 --- a/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivityFactory.scala +++ b/orchestration/src/main/scala/ai/chronon/orchestration/temporal/activity/NodeExecutionActivityFactory.scala @@ -1,12 +1,88 @@ package ai.chronon.orchestration.temporal.activity +import ai.chronon.orchestration.pubsub.{GcpPubSubConfig, PubSubManager, PubSubPublisher} import ai.chronon.orchestration.temporal.workflow.WorkflowOperationsImpl import io.temporal.client.WorkflowClient // Factory for creating activity implementations object NodeExecutionActivityFactory { + + /** Create a NodeExecutionActivity with default configuration from environment variables. 
+ * + * This method relies on environment variables for configuration: + * - GCP_PROJECT_ID: The Google Cloud project ID (required) + * - PUBSUB_TOPIC_ID: The PubSub topic for job submissions (required) + * + * @param workflowClient The Temporal workflow client + * @return A NodeExecutionActivity configured from environment variables + * @throws IllegalArgumentException if required environment variables are not set + */ def create(workflowClient: WorkflowClient): NodeExecutionActivity = { + // Get environment variables with validation + val projectId = sys.env.getOrElse( + "GCP_PROJECT_ID", + throw new IllegalArgumentException("Environment variable GCP_PROJECT_ID must be set")) + + val topicId = sys.env.getOrElse( + "PUBSUB_TOPIC_ID", + throw new IllegalArgumentException("Environment variable PUBSUB_TOPIC_ID must be set")) + + // Verify that they're not empty + if (projectId.trim.isEmpty) { + throw new IllegalArgumentException("Environment variable GCP_PROJECT_ID cannot be empty") + } + + if (topicId.trim.isEmpty) { + throw new IllegalArgumentException("Environment variable PUBSUB_TOPIC_ID cannot be empty") + } + + create(workflowClient, projectId, topicId) + } + + /** Create a NodeExecutionActivity with custom PubSub manager + */ + def create( + workflowClient: WorkflowClient, + pubSubManager: PubSubManager, + topicId: String + ): NodeExecutionActivity = { + val publisher = pubSubManager.getOrCreatePublisher(topicId) + + val workflowOps = new WorkflowOperationsImpl(workflowClient) + new NodeExecutionActivityImpl(workflowOps, publisher) + } + + /** Create a NodeExecutionActivity with explicit configuration + */ + def create(workflowClient: WorkflowClient, projectId: String, topicId: String): NodeExecutionActivity = { + // Create PubSub configuration based on environment + val manager = sys.env.get("PUBSUB_EMULATOR_HOST") match { + case Some(emulatorHost) => + // Use emulator configuration if PUBSUB_EMULATOR_HOST is set + PubSubManager.forEmulator(projectId, 
emulatorHost) + case None => + // Use default configuration for production + PubSubManager.forProduction(projectId) + } + + create(workflowClient, manager, topicId) + } + + /** Create a NodeExecutionActivity with custom PubSub configuration + */ + def create( + workflowClient: WorkflowClient, + config: GcpPubSubConfig, + topicId: String + ): NodeExecutionActivity = { + val manager = PubSubManager(config) + create(workflowClient, manager, topicId) + } + + /** Create a NodeExecutionActivity with a pre-configured PubSub publisher + */ + def create(workflowClient: WorkflowClient, pubSubPublisher: PubSubPublisher): NodeExecutionActivity = { val workflowOps = new WorkflowOperationsImpl(workflowClient) - new NodeExecutionActivityImpl(workflowOps) + new NodeExecutionActivityImpl(workflowOps, pubSubPublisher) } } diff --git a/orchestration/src/main/scala/ai/chronon/orchestration/utils/GcpPubSubAdminUtils.scala b/orchestration/src/main/scala/ai/chronon/orchestration/utils/GcpPubSubAdminUtils.scala new file mode 100644 index 0000000000..a364f505b4 --- /dev/null +++ b/orchestration/src/main/scala/ai/chronon/orchestration/utils/GcpPubSubAdminUtils.scala @@ -0,0 +1,60 @@ +package ai.chronon.orchestration.utils + +import ai.chronon.orchestration.pubsub.GcpPubSubConfig +import com.google.cloud.pubsub.v1.{ + SubscriptionAdminClient, + SubscriptionAdminSettings, + TopicAdminClient, + TopicAdminSettings +} +import org.slf4j.LoggerFactory + +/** Utility class for creating GCP PubSub admin clients + */ +object GcpPubSubAdminUtils { + private val logger = LoggerFactory.getLogger(getClass) + + /** Create a topic admin client for Google Cloud PubSub + * @param config The GCP PubSub configuration + * @return A TopicAdminClient configured with the provided settings + */ + def createTopicAdminClient(config: GcpPubSubConfig): TopicAdminClient = { + val topicAdminSettingsBuilder = TopicAdminSettings.newBuilder() + + // Add channel provider if specified + config.channelProvider.foreach { 
provider => + logger.info("Using custom channel provider for TopicAdminClient") + topicAdminSettingsBuilder.setTransportChannelProvider(provider) + } + + // Add credentials provider if specified + config.credentialsProvider.foreach { provider => + logger.info("Using custom credentials provider for TopicAdminClient") + topicAdminSettingsBuilder.setCredentialsProvider(provider) + } + + TopicAdminClient.create(topicAdminSettingsBuilder.build()) + } + + /** Create a subscription admin client for Google Cloud PubSub + * @param config The GCP PubSub configuration + * @return A SubscriptionAdminClient configured with the provided settings + */ + def createSubscriptionAdminClient(config: GcpPubSubConfig): SubscriptionAdminClient = { + val subscriptionAdminSettingsBuilder = SubscriptionAdminSettings.newBuilder() + + // Add channel provider if specified + config.channelProvider.foreach { provider => + logger.info("Using custom channel provider for SubscriptionAdminClient") + subscriptionAdminSettingsBuilder.setTransportChannelProvider(provider) + } + + // Add credentials provider if specified + config.credentialsProvider.foreach { provider => + logger.info("Using custom credentials provider for SubscriptionAdminClient") + subscriptionAdminSettingsBuilder.setCredentialsProvider(provider) + } + + SubscriptionAdminClient.create(subscriptionAdminSettingsBuilder.build()) + } +} diff --git a/orchestration/src/test/scala/ai/chronon/orchestration/test/pubsub/GcpPubSubIntegrationSpec.scala b/orchestration/src/test/scala/ai/chronon/orchestration/test/pubsub/GcpPubSubIntegrationSpec.scala new file mode 100644 index 0000000000..2017facc18 --- /dev/null +++ b/orchestration/src/test/scala/ai/chronon/orchestration/test/pubsub/GcpPubSubIntegrationSpec.scala @@ -0,0 +1,212 @@ +package ai.chronon.orchestration.test.pubsub + +import ai.chronon.orchestration.DummyNode +import ai.chronon.orchestration.pubsub._ +import org.scalatest.BeforeAndAfterAll +import org.scalatest.flatspec.AnyFlatSpec 
+import org.scalatest.matchers.should.Matchers + +import java.util.UUID +import java.util.concurrent.TimeUnit +import scala.util.Try + +/** Integration tests for PubSub components with the emulator. + * + * Prerequisites: + * - PubSub emulator must be running + * - PUBSUB_EMULATOR_HOST environment variable must be set (e.g., localhost:8085) + */ +class GcpPubSubIntegrationSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { + + // Test configuration + private val emulatorHost = sys.env.getOrElse("PUBSUB_EMULATOR_HOST", "localhost:8085") + private val projectId = "test-project" + private val testId = UUID.randomUUID().toString.take(8) // Generate unique IDs for tests + private val topicId = s"integration-topic-$testId" + private val subscriptionId = s"integration-sub-$testId" + + // Components under test + private var pubSubManager: PubSubManager = _ + private var pubSubAdmin: PubSubAdmin = _ + private var publisher: PubSubPublisher = _ + private var subscriber: PubSubSubscriber = _ + + override def beforeAll(): Unit = { + // Check if the emulator is available + assume( + sys.env.contains("PUBSUB_EMULATOR_HOST"), + "PubSub emulator not available. Set PUBSUB_EMULATOR_HOST environment variable." 
+ ) + + // Create test configuration and components + val config = GcpPubSubConfig.forEmulator(projectId, emulatorHost) + pubSubManager = PubSubManager(config) + pubSubAdmin = PubSubAdmin(config) + + // Create topic and subscription + Try { + pubSubAdmin.createTopic(topicId) + pubSubAdmin.createSubscription(topicId, subscriptionId) + }.recover { case e: Exception => + fail(s"Failed to set up PubSub resources: ${e.getMessage}") + } + + // Get publisher and subscriber + publisher = pubSubManager.getOrCreatePublisher(topicId) + subscriber = pubSubManager.getOrCreateSubscriber(topicId, subscriptionId) + } + + override def afterAll(): Unit = { + // Clean up all resources + Try { + if (publisher != null) publisher.shutdown() + if (subscriber != null) subscriber.shutdown() + if (pubSubAdmin != null) pubSubAdmin.close() + if (pubSubManager != null) pubSubManager.shutdown() + } + } + + "PubSubAdmin" should "create topics and subscriptions idempotent" in { + // Create unique IDs for this test + val testTopicId = s"topic-admin-test-${UUID.randomUUID().toString.take(8)}" + val testSubId = s"sub-admin-test-${UUID.randomUUID().toString.take(8)}" + + // Create topic + pubSubAdmin.createTopic(testTopicId) + + // Call again to test idempotence (should not throw error) + pubSubAdmin.createTopic(testTopicId) + + // Create subscription + pubSubAdmin.createSubscription(testTopicId, testSubId) + + // Call again to test idempotence (should not throw error) + pubSubAdmin.createSubscription(testTopicId, testSubId) + } + + "PubSubAdmin" should "handle creating multiple topics and subscriptions" in { + // Create unique IDs for this test + val testTopicId1 = s"topic-multi-1-${UUID.randomUUID().toString.take(8)}" + val testTopicId2 = s"topic-multi-2-${UUID.randomUUID().toString.take(8)}" + val testSubId1 = s"sub-multi-1-${UUID.randomUUID().toString.take(8)}" + val testSubId2 = s"sub-multi-2-${UUID.randomUUID().toString.take(8)}" + + // Create multiple topics + 
pubSubAdmin.createTopic(testTopicId1) + pubSubAdmin.createTopic(testTopicId2) + + // Create multiple subscriptions + pubSubAdmin.createSubscription(testTopicId1, testSubId1) + pubSubAdmin.createSubscription(testTopicId2, testSubId2) + } + + "PubSubPublisher and PubSubSubscriber" should "publish and receive messages" in { + // Create a test message + val message = JobSubmissionMessage( + nodeName = "integration-test", + data = Some("Test message for integration testing"), + attributes = Map("test" -> "true") + ) + + // Publish the message + val messageIdFuture = publisher.publish(message) + val messageId = messageIdFuture.get(5, TimeUnit.SECONDS) + messageId should not be null + + // Pull messages + val messages = subscriber.pullMessages(10) + messages.size should be(1) + + // Find our message + val receivedMessage = findMessageByNodeName(messages, "integration-test") + receivedMessage should be(defined) + + // Verify contents + val pubsubMsg = receivedMessage.get + pubsubMsg.getAttributes.getOrElse("nodeName", "") should be("integration-test") + pubsubMsg.getAttributes.getOrElse("test", "") should be("true") + } + + "JobSubmissionMessage" should "work with DummyNode conversion in real environment" in { + // Create a DummyNode + val dummyNode = new DummyNode().setName("dummy-node-test") + + // Convert to message + val message = JobSubmissionMessage.fromDummyNode(dummyNode) + message.nodeName should be("dummy-node-test") + + // Publish the message + val messageId = publisher.publish(message).get(5, TimeUnit.SECONDS) + messageId should not be null + + // Pull and verify + val messages = subscriber.pullMessages(10) + val receivedMessage = findMessageByNodeName(messages, "dummy-node-test") + receivedMessage should be(defined) + + // Verify content + val pubsubMsg = receivedMessage.get + pubsubMsg.getAttributes.getOrElse("nodeName", "") should be("dummy-node-test") + } + + "PubSubManager" should "properly handle multiple publishers and subscribers" in { + // Create 
unique IDs for this test + val testTopicId = s"topic-multi-test-${UUID.randomUUID().toString.take(8)}" + val testSubId1 = s"sub-multi-test-1-${UUID.randomUUID().toString.take(8)}" + val testSubId2 = s"sub-multi-test-2-${UUID.randomUUID().toString.take(8)}" + + // Create topic and subscriptions + pubSubAdmin.createTopic(testTopicId) + pubSubAdmin.createSubscription(testTopicId, testSubId1) + pubSubAdmin.createSubscription(testTopicId, testSubId2) + + // Get publishers and subscribers + val testPublisher = pubSubManager.getOrCreatePublisher(testTopicId) + val testSubscriber1 = pubSubManager.getOrCreateSubscriber(testTopicId, testSubId1) + val testSubscriber2 = pubSubManager.getOrCreateSubscriber(testTopicId, testSubId2) + + // Publish a message + val message = JobSubmissionMessage("multi-test", Some("Testing multiple subscribers")) + testPublisher.publish(message).get(5, TimeUnit.SECONDS) + + // Both subscribers should receive the message + val messages1 = testSubscriber1.pullMessages(10) + val messages2 = testSubscriber2.pullMessages(10) + + // Verify messages from both subscribers + findMessageByNodeName(messages1, "multi-test") should be(defined) + findMessageByNodeName(messages2, "multi-test") should be(defined) + } + + "PubSubPublisher" should "handle batch publishing" in { + // Create and publish multiple messages + val messageCount = 5 + val messageIds = (1 to messageCount).map { i => + val message = JobSubmissionMessage(s"batch-node-$i", Some(s"Batch message $i")) + publisher.publish(message).get(5, TimeUnit.SECONDS) + } + + // Verify all messages got IDs + messageIds.size should be(messageCount) + messageIds.foreach(_ should not be null) + + // Pull messages + val messages = subscriber.pullMessages(messageCount + 5) // Add buffer + + // Verify all node names are present + val foundNodeNames = messages.map(msg => msg.getAttributes.getOrElse("nodeName", "")).toSet + + // Check each batch message is found + (1 to messageCount).foreach { i => + val nodeName = 
s"batch-node-$i" + withClue(s"Missing message for node $nodeName: ") { + foundNodeNames should contain(nodeName) + } + } + } + + // Helper method to find a message by node name + private def findMessageByNodeName(messages: Seq[PubSubMessage], nodeName: String): Option[PubSubMessage] = { + messages.find(msg => msg.getAttributes.getOrElse("nodeName", "") == nodeName) + } +} diff --git a/orchestration/src/test/scala/ai/chronon/orchestration/test/pubsub/GcpPubSubSpec.scala b/orchestration/src/test/scala/ai/chronon/orchestration/test/pubsub/GcpPubSubSpec.scala new file mode 100644 index 0000000000..ad7c46691e --- /dev/null +++ b/orchestration/src/test/scala/ai/chronon/orchestration/test/pubsub/GcpPubSubSpec.scala @@ -0,0 +1,446 @@ +package ai.chronon.orchestration.test.pubsub + +import ai.chronon.orchestration.DummyNode +import ai.chronon.orchestration.pubsub._ +import com.google.api.core.ApiFuture +import com.google.api.gax.core.NoCredentialsProvider +import com.google.api.gax.rpc.{NotFoundException, StatusCode} +import com.google.cloud.pubsub.v1.{Publisher, SubscriptionAdminClient, TopicAdminClient} +import com.google.pubsub.v1.{ + PubsubMessage, + PullResponse, + ReceivedMessage, + Subscription, + SubscriptionName, + Topic, + TopicName +} +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar + +import java.util + +/** Unit tests for PubSub components using mocks */ +class GcpPubSubSpec extends AnyFlatSpec with Matchers with MockitoSugar { + + private val notFoundException = new NotFoundException("Not found", null, mock[StatusCode], false) + + "GcpPubSubConfig" should "create production configuration" in { + val config = GcpPubSubConfig.forProduction("test-project") + + config.projectId shouldBe "test-project" + config.channelProvider shouldBe None + config.credentialsProvider shouldBe None + } + + "GcpPubSubConfig" 
should "create emulator configuration" in { + val config = GcpPubSubConfig.forEmulator("test-project") + + config.projectId shouldBe "test-project" + config.channelProvider shouldBe defined + config.credentialsProvider shouldBe defined + config.credentialsProvider.get.getClass shouldBe NoCredentialsProvider.create().getClass + } + + "JobSubmissionMessage" should "convert to PubsubMessage correctly" in { + val message = JobSubmissionMessage( + nodeName = "test-node", + data = Some("Test data"), + attributes = Map("customKey" -> "customValue") + ) + + val pubsubMessage = message.toPubsubMessage + + pubsubMessage.getAttributesMap.get("nodeName") shouldBe "test-node" + pubsubMessage.getAttributesMap.get("customKey") shouldBe "customValue" + pubsubMessage.getData.toStringUtf8 shouldBe "Test data" + } + + "JobSubmissionMessage" should "create from DummyNode correctly" in { + val node = new DummyNode().setName("test-node") + val message = JobSubmissionMessage.fromDummyNode(node) + + message.nodeName shouldBe "test-node" + message.data shouldBe defined + message.data.get should include("test-node") + + val pubsubMessage = message.toPubsubMessage + pubsubMessage.getAttributesMap.get("nodeName") shouldBe "test-node" + } + + "GcpPubSubPublisher" should "publish messages successfully" in { + // Mock dependencies + val mockPublisher = mock[Publisher] + val mockFuture = mock[ApiFuture[String]] + + // Set up config and topic + val config = GcpPubSubConfig.forEmulator("test-project") + val topicId = "test-topic" + + // Create a test publisher that uses the mock publisher + val publisher = new GcpPubSubPublisher(config, topicId) { + // Expose createPublisher as a test hook and override to return mock + override def createPublisher(): Publisher = mockPublisher + } + + // Set up the mock publisher to return our mock future + when(mockPublisher.publish(any[PubsubMessage])).thenReturn(mockFuture) + + // Create a message and attempt to publish + val message = 
JobSubmissionMessage("test-node", Some("Test data")) + val resultFuture = publisher.publish(message) + + // Verify publisher was called with message + verify(mockPublisher).publish(any[PubsubMessage]) + + // Cleaning up + publisher.shutdown() + } + + "GcpPubSubAdmin createTopic" should "successfully create a new topic" in { + // Mock the TopicAdminClient + val mockTopicAdmin = mock[TopicAdminClient] + + // Create a mock admin that uses our mock + val admin = new GcpPubSubAdmin(GcpPubSubConfig.forEmulator("test-project")) { + override protected lazy val topicAdminClient: TopicAdminClient = mockTopicAdmin + override protected lazy val subscriptionAdminClient: SubscriptionAdminClient = mock[SubscriptionAdminClient] + } + + // Mock the creation response for a new topic + when(mockTopicAdmin.createTopic(any[TopicName])).thenReturn(mock[Topic]) + + // Create the topic + admin.createTopic("test-topic") + + // Verify createTopic was called + verify(mockTopicAdmin).createTopic(any[TopicName]) + + // Cleanup + admin.close() + } + + "GcpPubSubAdmin createTopic" should "handle the case when a topic already exists" in { + // Mock the TopicAdminClient + val mockTopicAdmin = mock[TopicAdminClient] + + // Create a mock admin that uses our mock + val admin = new GcpPubSubAdmin(GcpPubSubConfig.forEmulator("test-project")) { + override protected lazy val topicAdminClient: TopicAdminClient = mockTopicAdmin + override protected lazy val subscriptionAdminClient: SubscriptionAdminClient = mock[SubscriptionAdminClient] + } + + // Mock the creation response to throw ALREADY_EXISTS exception + val alreadyExistsException = new RuntimeException("ALREADY_EXISTS: Topic already exists") + when(mockTopicAdmin.createTopic(any[TopicName])).thenThrow(alreadyExistsException) + + // Try to create the topic - should not throw exception + admin.createTopic("test-topic") + + // Verify createTopic was called and exception was handled internally + verify(mockTopicAdmin).createTopic(any[TopicName]) + + // 
Cleanup + admin.close() + } + + "GcpPubSubAdmin createTopic" should "throw exception for errors other than 'already exists'" in { + // Mock the TopicAdminClient + val mockTopicAdmin = mock[TopicAdminClient] + + // Create a mock admin that uses our mock + val admin = new GcpPubSubAdmin(GcpPubSubConfig.forEmulator("test-project")) { + override protected lazy val topicAdminClient: TopicAdminClient = mockTopicAdmin + override protected lazy val subscriptionAdminClient: SubscriptionAdminClient = mock[SubscriptionAdminClient] + } + + // Mock the create response to throw a different type of exception + val otherException = new RuntimeException("PERMISSION_DENIED: Not authorized to create topic") + when(mockTopicAdmin.createTopic(any[TopicName])).thenThrow(otherException) + + // Try to create the topic - should throw the exception + val exception = intercept[RuntimeException] { + admin.createTopic("test-topic") + } + + // Verify the exception is the same as the one we mocked + exception shouldBe otherException + + // Verify createTopic was called + verify(mockTopicAdmin).createTopic(any[TopicName]) + } + + "GcpPubSubAdmin createSubscription" should "successfully create a new subscription" in { + // Mock the SubscriptionAdminClient + val mockSubscriptionAdmin = mock[SubscriptionAdminClient] + + // Create a mock admin that uses our mock + val admin = new GcpPubSubAdmin(GcpPubSubConfig.forEmulator("test-project")) { + override protected lazy val topicAdminClient: TopicAdminClient = mock[TopicAdminClient] + override protected lazy val subscriptionAdminClient: SubscriptionAdminClient = mockSubscriptionAdmin + } + + // Mock the creation response for a new subscription + when( + mockSubscriptionAdmin.createSubscription( + any[SubscriptionName], + any[TopicName], + any(), + any[Int] + )).thenReturn(mock[Subscription]) + + // Create the subscription + admin.createSubscription("test-topic", "test-sub") + + // Verify createSubscription was called + 
verify(mockSubscriptionAdmin).createSubscription( + any[SubscriptionName], + any[TopicName], + any(), + any[Int] + ) + + // Cleanup + admin.close() + } + + "GcpPubSubAdmin createSubscription" should "handle the case when a subscription already exists" in { + // Mock the SubscriptionAdminClient + val mockSubscriptionAdmin = mock[SubscriptionAdminClient] + + // Create a mock admin that uses our mock + val admin = new GcpPubSubAdmin(GcpPubSubConfig.forEmulator("test-project")) { + override protected lazy val topicAdminClient: TopicAdminClient = mock[TopicAdminClient] + override protected lazy val subscriptionAdminClient: SubscriptionAdminClient = mockSubscriptionAdmin + } + + // Mock the creation response to throw ALREADY_EXISTS exception + val alreadyExistsException = new RuntimeException("ALREADY_EXISTS: Subscription already exists") + when( + mockSubscriptionAdmin.createSubscription( + any[SubscriptionName], + any[TopicName], + any(), + any[Int] + )).thenThrow(alreadyExistsException) + + // Try to create the subscription - should not throw exception + admin.createSubscription("test-topic", "test-sub") + + // Verify createSubscription was called and exception was handled internally + verify(mockSubscriptionAdmin).createSubscription( + any[SubscriptionName], + any[TopicName], + any(), + any[Int] + ) + + // Cleanup + admin.close() + } + + "GcpPubSubAdmin createSubscription" should "throw exception for errors other than 'already exists'" in { + // Mock the SubscriptionAdminClient + val mockSubscriptionAdmin = mock[SubscriptionAdminClient] + + // Create a mock admin that uses our mock + val admin = new GcpPubSubAdmin(GcpPubSubConfig.forEmulator("test-project")) { + override protected lazy val topicAdminClient: TopicAdminClient = mock[TopicAdminClient] + override protected lazy val subscriptionAdminClient: SubscriptionAdminClient = mockSubscriptionAdmin + } + + // Mock the create response to throw a different type of exception + val otherException = new 
RuntimeException("INVALID_ARGUMENT: Invalid subscription name") + when( + mockSubscriptionAdmin.createSubscription( + any[SubscriptionName], + any[TopicName], + any(), + any[Int] + )).thenThrow(otherException) + + // Try to create the subscription - should throw the exception + val exception = intercept[RuntimeException] { + admin.createSubscription("test-topic", "test-sub") + } + + // Verify the exception is the same as the one we mocked + exception shouldBe otherException + + // Verify createSubscription was called + verify(mockSubscriptionAdmin).createSubscription( + any[SubscriptionName], + any[TopicName], + any(), + any[Int] + ) + } + + "GcpPubSubSubscriber" should "pull messages correctly" in { + // Mock the subscription admin client + val mockSubscriptionAdmin = mock[SubscriptionAdminClient] + + // Mock the pull response + val mockPullResponse = mock[PullResponse] + val mockReceivedMessage = mock[ReceivedMessage] + val mockPubsubMessage = mock[PubsubMessage] + + // Set up the mocks + when(mockReceivedMessage.getMessage).thenReturn(mockPubsubMessage) + when(mockReceivedMessage.getAckId).thenReturn("test-ack-id") + when(mockPullResponse.getReceivedMessagesList).thenReturn(util.Arrays.asList(mockReceivedMessage)) + when(mockSubscriptionAdmin.pull(any[SubscriptionName], any[Int])).thenReturn(mockPullResponse) + + // Create a test configuration + val config = GcpPubSubConfig.forEmulator("test-project") + + // Create a test subscriber that uses our mock admin client + val subscriber = new GcpPubSubSubscriber(config, "test-sub") { + override protected val adminClient: SubscriptionAdminClient = mockSubscriptionAdmin + } + + // Pull messages + val messages = subscriber.pullMessages(10) + + // Verify + messages.size shouldBe 1 + messages.head shouldBe a[PubSubMessage] + + // Verify acknowledge was called + verify(mockSubscriptionAdmin).acknowledge(any[SubscriptionName], any()) + + // Cleanup + subscriber.shutdown() + } + + "GcpPubSubSubscriber" should "throw 
RuntimeException when there is a pull error" in { + // Mock the subscription admin client + val mockSubscriptionAdmin = mock[SubscriptionAdminClient] + + // Set up the mock to throw an exception when pulling messages + val errorMessage = "Error pulling messages" + when(mockSubscriptionAdmin.pull(any[SubscriptionName], any[Int])) + .thenThrow(new RuntimeException(errorMessage)) + + // Create a test configuration + val config = GcpPubSubConfig.forEmulator("test-project") + + // Create a test subscriber that uses our mock admin client + val subscriber = new GcpPubSubSubscriber(config, "test-sub") { + override protected val adminClient: SubscriptionAdminClient = mockSubscriptionAdmin + } + + // Pull messages - should throw an exception + val exception = intercept[RuntimeException] { + subscriber.pullMessages(10) + } + + // Verify the exception message + exception.getMessage should include(errorMessage) + + // Cleanup + subscriber.shutdown() + } + + "PubSubManager" should "cache publishers and subscribers" in { + // Create mock admin, publisher, and subscriber + val mockAdmin = mock[PubSubAdmin] + val mockPublisher1 = mock[PubSubPublisher] + val mockPublisher2 = mock[PubSubPublisher] + val mockSubscriber1 = mock[PubSubSubscriber] + val mockSubscriber2 = mock[PubSubSubscriber] + + // Configure the mocks - don't need to return values for void methods + doNothing().when(mockAdmin).createTopic(any[String]) + doNothing().when(mockAdmin).createSubscription(any[String], any[String]) + + when(mockPublisher1.topicId).thenReturn("topic1") + when(mockPublisher2.topicId).thenReturn("topic2") + when(mockSubscriber1.subscriptionId).thenReturn("sub1") + when(mockSubscriber2.subscriptionId).thenReturn("sub2") + + // Create a test manager with mocked components + val config = GcpPubSubConfig.forEmulator("test-project") + val manager = new GcpPubSubManager(config) { + override protected val admin: PubSubAdmin = mockAdmin + + // Cache for our test publishers and subscribers + private val 
testPublishers = Map( + "topic1" -> mockPublisher1, + "topic2" -> mockPublisher2 + ) + + private val testSubscribers = Map( + "sub1" -> mockSubscriber1, + "sub2" -> mockSubscriber2 + ) + + // Override publisher creation to return our mocks + override def getOrCreatePublisher(topicId: String): PubSubPublisher = { + admin.createTopic(topicId) + testPublishers.getOrElse(topicId, { + val pub = mock[PubSubPublisher] + when(pub.topicId).thenReturn(topicId) + pub + }) + } + + // Override subscriber creation to return our mocks + override def getOrCreateSubscriber(topicId: String, subscriptionId: String): PubSubSubscriber = { + admin.createSubscription(topicId, subscriptionId) + testSubscribers.getOrElse(subscriptionId, { + val sub = mock[PubSubSubscriber] + when(sub.subscriptionId).thenReturn(subscriptionId) + sub + }) + } + } + + // Test publisher retrieval - should get the same instances for same topic + val pub1First = manager.getOrCreatePublisher("topic1") + val pub1Second = manager.getOrCreatePublisher("topic1") + val pub2 = manager.getOrCreatePublisher("topic2") + + pub1First shouldBe mockPublisher1 + pub1Second shouldBe mockPublisher1 + pub2 shouldBe mockPublisher2 + + // Test subscriber retrieval - should get same instances for same subscription + val sub1First = manager.getOrCreateSubscriber("topic1", "sub1") + val sub1Second = manager.getOrCreateSubscriber("topic1", "sub1") + val sub2 = manager.getOrCreateSubscriber("topic1", "sub2") + + sub1First shouldBe mockSubscriber1 + sub1Second shouldBe mockSubscriber1 + sub2 shouldBe mockSubscriber2 + + // Verify the admin calls + verify(mockAdmin, times(2)).createTopic("topic1") + verify(mockAdmin).createTopic("topic2") + verify(mockAdmin, times(2)).createSubscription("topic1", "sub1") + verify(mockAdmin).createSubscription("topic1", "sub2") + + // Cleanup + manager.shutdown() + } + + "PubSubManager companion" should "cache managers by config" in { + // Create test configs + val config1 = 
GcpPubSubConfig.forEmulator("project1") + val config2 = GcpPubSubConfig.forEmulator("project2") // Different project + + // Test manager caching + val manager1 = PubSubManager(config1) + val manager2 = PubSubManager(config1) + val manager3 = PubSubManager(config2) + + manager1 shouldBe theSameInstanceAs(manager2) // Same config should reuse + manager1 should not be theSameInstanceAs(manager3) // Different config = different manager + + // Cleanup + PubSubManager.shutdownAll() + } +} diff --git a/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/activity/NodeExecutionActivityTest.scala b/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/activity/NodeExecutionActivityTest.scala index c52697839d..a3cada62a9 100644 --- a/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/activity/NodeExecutionActivityTest.scala +++ b/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/activity/NodeExecutionActivityTest.scala @@ -1,6 +1,7 @@ package ai.chronon.orchestration.test.temporal.activity import ai.chronon.orchestration.DummyNode +import ai.chronon.orchestration.pubsub.{JobSubmissionMessage, PubSubPublisher} import ai.chronon.orchestration.temporal.activity.{NodeExecutionActivity, NodeExecutionActivityImpl} import ai.chronon.orchestration.temporal.constants.NodeExecutionWorkflowTaskQueue import ai.chronon.orchestration.temporal.workflow.WorkflowOperations @@ -10,6 +11,7 @@ import io.temporal.client.{WorkflowClient, WorkflowOptions} import io.temporal.testing.TestWorkflowEnvironment import io.temporal.worker.Worker import io.temporal.workflow.{Workflow, WorkflowInterface, WorkflowMethod} +import org.mockito.{ArgumentCaptor, ArgumentMatchers} import org.mockito.Mockito.{atLeastOnce, verify, when} import org.scalatest.BeforeAndAfterEach import org.scalatest.flatspec.AnyFlatSpec @@ -20,16 +22,18 @@ import java.lang.{Void => JavaVoid} import java.time.Duration import java.util.concurrent.CompletableFuture -// Test 
workflow just for activity testing -// This is needed for testing manual completion logic for our activity as it's not supported for +// Test workflows for activity testing +// These are needed for testing manual completion logic for our activities as it's not supported for // test activity environment + +// Workflow for testing triggerDependency @WorkflowInterface -trait TestActivityWorkflow { +trait TestTriggerDependencyWorkflow { @WorkflowMethod def triggerDependency(node: DummyNode): Unit } -class TestActivityWorkflowImpl extends TestActivityWorkflow { +class TestTriggerDependencyWorkflowImpl extends TestTriggerDependencyWorkflow { private val activity = Workflow.newActivityStub( classOf[NodeExecutionActivity], ActivityOptions @@ -43,6 +47,27 @@ class TestActivityWorkflowImpl extends TestActivityWorkflow { } } +// Workflow for testing submitJob +@WorkflowInterface +trait TestSubmitJobWorkflow { + @WorkflowMethod + def submitJob(node: DummyNode): Unit +} + +class TestSubmitJobWorkflowImpl extends TestSubmitJobWorkflow { + private val activity = Workflow.newActivityStub( + classOf[NodeExecutionActivity], + ActivityOptions + .newBuilder() + .setStartToCloseTimeout(Duration.ofSeconds(5)) + .build() + ) + + override def submitJob(node: DummyNode): Unit = { + activity.submitJob(node) + } +} + class NodeExecutionActivityTest extends AnyFlatSpec with Matchers with BeforeAndAfterEach with MockitoSugar { private val workflowOptions = WorkflowOptions @@ -55,26 +80,34 @@ class NodeExecutionActivityTest extends AnyFlatSpec with Matchers with BeforeAnd private var worker: Worker = _ private var workflowClient: WorkflowClient = _ private var mockWorkflowOps: WorkflowOperations = _ - private var testActivityWorkflow: TestActivityWorkflow = _ + private var mockPublisher: PubSubPublisher = _ + private var testTriggerWorkflow: TestTriggerDependencyWorkflow = _ + private var testSubmitWorkflow: TestSubmitJobWorkflow = _ override def beforeEach(): Unit = { testEnv = 
TemporalTestEnvironmentUtils.getTestWorkflowEnv worker = testEnv.newWorker(NodeExecutionWorkflowTaskQueue.toString) - worker.registerWorkflowImplementationTypes(classOf[TestActivityWorkflowImpl]) + worker.registerWorkflowImplementationTypes( + classOf[TestTriggerDependencyWorkflowImpl], + classOf[TestSubmitJobWorkflowImpl] + ) workflowClient = testEnv.getWorkflowClient - // Create mock workflow operations + // Create mock dependencies mockWorkflowOps = mock[WorkflowOperations] + mockPublisher = mock[PubSubPublisher] + when(mockPublisher.topicId).thenReturn("test-topic") // Create activity with mocked dependencies - val activity = new NodeExecutionActivityImpl(mockWorkflowOps) + val activity = new NodeExecutionActivityImpl(mockWorkflowOps, mockPublisher) worker.registerActivitiesImplementations(activity) // Start the test environment testEnv.start() - // Create test activity workflow - testActivityWorkflow = workflowClient.newWorkflowStub(classOf[TestActivityWorkflow], workflowOptions) + // Create test activity workflows + testTriggerWorkflow = workflowClient.newWorkflowStub(classOf[TestTriggerDependencyWorkflow], workflowOptions) + testSubmitWorkflow = workflowClient.newWorkflowStub(classOf[TestSubmitJobWorkflow], workflowOptions) } override def afterEach(): Unit = { @@ -91,7 +124,7 @@ class NodeExecutionActivityTest extends AnyFlatSpec with Matchers with BeforeAnd when(mockWorkflowOps.startNodeWorkflow(testNode)).thenReturn(completedFuture) // Trigger activity method - testActivityWorkflow.triggerDependency(testNode) + testTriggerWorkflow.triggerDependency(testNode) // Assert verify(mockWorkflowOps).startNodeWorkflow(testNode) @@ -108,7 +141,7 @@ class NodeExecutionActivityTest extends AnyFlatSpec with Matchers with BeforeAnd // Trigger activity and expect it to fail val exception = intercept[RuntimeException] { - testActivityWorkflow.triggerDependency(testNode) + testTriggerWorkflow.triggerDependency(testNode) } // Verify that the exception is propagated 
correctly @@ -119,26 +152,42 @@ class NodeExecutionActivityTest extends AnyFlatSpec with Matchers with BeforeAnd } it should "submit job successfully" in { - val testActivityEnvironment = TemporalTestEnvironmentUtils.getTestActivityEnv - - // Get the activity stub (interface) to use for testing - val activity = testActivityEnvironment.newActivityStub( - classOf[NodeExecutionActivity], - ActivityOptions - .newBuilder() - .setScheduleToCloseTimeout(Duration.ofSeconds(10)) - .build() - ) + val testNode = new DummyNode().setName("test-node") + val completedFuture = CompletableFuture.completedFuture("message-id-123") - // Create activity implementation with mock workflow operations - val activityImpl = new NodeExecutionActivityImpl(mockWorkflowOps) + // Mock PubSub publisher to return a completed future + when(mockPublisher.publish(ArgumentMatchers.any[JobSubmissionMessage])).thenReturn(completedFuture) - // Register activity implementation with the test environment - testActivityEnvironment.registerActivitiesImplementations(activityImpl) + // Trigger activity method + testSubmitWorkflow.submitJob(testNode) - val testNode = new DummyNode().setName("test-node") + // Use a capture to verify the message passed to the publisher + val messageCaptor = ArgumentCaptor.forClass(classOf[JobSubmissionMessage]) + verify(mockPublisher).publish(messageCaptor.capture()) + + // Verify the message content + val capturedMessage = messageCaptor.getValue + capturedMessage.nodeName should be(testNode.name) + } + + it should "fail when publishing to PubSub fails" in { + val testNode = new DummyNode().setName("failing-node") + val expectedException = new RuntimeException("Failed to publish message") + val failedFuture = new CompletableFuture[String]() + failedFuture.completeExceptionally(expectedException) + + // Mock PubSub publisher to return a failed future + when(mockPublisher.publish(ArgumentMatchers.any[JobSubmissionMessage])).thenReturn(failedFuture) + + // Trigger activity and expect 
it to fail + val exception = intercept[RuntimeException] { + testSubmitWorkflow.submitJob(testNode) + } + + // Verify that the exception is propagated correctly + exception.getMessage should include("failed") - activity.submitJob(testNode) - testActivityEnvironment.close() + // Verify the message was passed to the publisher + verify(mockPublisher, atLeastOnce()).publish(ArgumentMatchers.any[JobSubmissionMessage]) } } diff --git a/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowFullDagSpec.scala b/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowFullDagSpec.scala index c46e769d83..0a7fbe24a3 100644 --- a/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowFullDagSpec.scala +++ b/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowFullDagSpec.scala @@ -1,5 +1,6 @@ package ai.chronon.orchestration.test.temporal.workflow +import ai.chronon.orchestration.pubsub.{GcpPubSubMessage, PubSubMessage, PubSubPublisher} import ai.chronon.orchestration.temporal.activity.NodeExecutionActivityImpl import ai.chronon.orchestration.temporal.constants.NodeExecutionWorkflowTaskQueue import ai.chronon.orchestration.temporal.workflow.{ @@ -12,15 +13,21 @@ import io.temporal.api.enums.v1.WorkflowExecutionStatus import io.temporal.client.WorkflowClient import io.temporal.testing.TestWorkflowEnvironment import io.temporal.worker.Worker +import org.mockito.ArgumentMatchers +import org.mockito.Mockito.when import org.scalatest.BeforeAndAfterEach import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar.mock + +import java.util.concurrent.CompletableFuture class NodeExecutionWorkflowFullDagSpec extends AnyFlatSpec with Matchers with BeforeAndAfterEach { private var testEnv: TestWorkflowEnvironment = _ private var worker: Worker = _ 
private var workflowClient: WorkflowClient = _ + private var mockPublisher: PubSubPublisher = _ private var mockWorkflowOps: WorkflowOperations = _ override def beforeEach(): Unit = { @@ -32,9 +39,14 @@ class NodeExecutionWorkflowFullDagSpec extends AnyFlatSpec with Matchers with Be // Mock workflow operations mockWorkflowOps = new WorkflowOperationsImpl(workflowClient) - // Create activity with mocked dependencies - val activity = new NodeExecutionActivityImpl(mockWorkflowOps) + // Mock PubSub publisher + mockPublisher = mock[PubSubPublisher] + val completedFuture = CompletableFuture.completedFuture("message-id-123") + when(mockPublisher.publish(ArgumentMatchers.any[PubSubMessage])).thenReturn(completedFuture) + when(mockPublisher.topicId).thenReturn("test-topic") + // Create activity with mocked dependencies + val activity = new NodeExecutionActivityImpl(mockWorkflowOps, mockPublisher) worker.registerActivitiesImplementations(activity) // Start the test environment diff --git a/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowIntegrationSpec.scala b/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowIntegrationSpec.scala index 3763b7f88f..03d234f253 100644 --- a/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowIntegrationSpec.scala +++ b/orchestration/src/test/scala/ai/chronon/orchestration/test/temporal/workflow/NodeExecutionWorkflowIntegrationSpec.scala @@ -1,8 +1,8 @@ package ai.chronon.orchestration.test.temporal.workflow +import ai.chronon.orchestration.pubsub.{PubSubAdmin, GcpPubSubConfig, PubSubManager, PubSubPublisher, PubSubSubscriber} import ai.chronon.orchestration.temporal.activity.NodeExecutionActivityFactory import ai.chronon.orchestration.temporal.constants.NodeExecutionWorkflowTaskQueue -import ai.chronon.orchestration.temporal.converter.ThriftPayloadConverter import 
ai.chronon.orchestration.temporal.workflow.{ NodeExecutionWorkflowImpl, WorkflowOperations, @@ -10,9 +10,7 @@ import ai.chronon.orchestration.temporal.workflow.{ } import ai.chronon.orchestration.test.utils.{TemporalTestEnvironmentUtils, TestNodeUtils} import io.temporal.api.enums.v1.WorkflowExecutionStatus -import io.temporal.client.{WorkflowClient, WorkflowClientOptions} -import io.temporal.common.converter.DefaultDataConverter -import io.temporal.serviceclient.WorkflowServiceStubs +import io.temporal.client.WorkflowClient import io.temporal.worker.WorkerFactory import org.scalatest.BeforeAndAfterAll import org.scalatest.flatspec.AnyFlatSpec @@ -20,53 +18,132 @@ import org.scalatest.matchers.should.Matchers /** This will trigger workflow runs on the local temporal server, so the pre-requisite would be to have the * temporal service running locally using `temporal server start-dev` + * + * For Pub/Sub testing, you also need: + * 1. Start the Pub/Sub emulator: gcloud beta emulators pubsub start --project=test-project + * 2. 
Set environment variable: export PUBSUB_EMULATOR_HOST=localhost:8085 */ class NodeExecutionWorkflowIntegrationSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll { + // Pub/Sub test configuration + private val projectId = "test-project" + private val topicId = "test-topic" + private val subscriptionId = "test-subscription" + + // Temporal variables private var workflowClient: WorkflowClient = _ private var workflowOperations: WorkflowOperations = _ private var factory: WorkerFactory = _ + // PubSub variables + private var pubSubManager: PubSubManager = _ + private var publisher: PubSubPublisher = _ + private var subscriber: PubSubSubscriber = _ + private var admin: PubSubAdmin = _ + override def beforeAll(): Unit = { - workflowClient = TemporalTestEnvironmentUtils.getLocalWorkflowClient + // Set up Pub/Sub emulator resources + setupPubSubResources() + // Set up Temporal + workflowClient = TemporalTestEnvironmentUtils.getLocalWorkflowClient workflowOperations = new WorkflowOperationsImpl(workflowClient) - factory = WorkerFactory.newInstance(workflowClient) // Setup worker for node workflow execution val worker = factory.newWorker(NodeExecutionWorkflowTaskQueue.toString) worker.registerWorkflowImplementationTypes(classOf[NodeExecutionWorkflowImpl]) - worker.registerActivitiesImplementations(NodeExecutionActivityFactory.create(workflowClient)) + + // Create and register activity with PubSub configured + val activity = NodeExecutionActivityFactory.create(workflowClient, publisher) + worker.registerActivitiesImplementations(activity) // Start all registered Workers. The Workers will start polling the Task Queue. 
factory.start() } + private def setupPubSubResources(): Unit = { + val emulatorHost = sys.env.getOrElse("PUBSUB_EMULATOR_HOST", "localhost:8085") + val config = GcpPubSubConfig.forEmulator(projectId, emulatorHost) + + // Create necessary PubSub components + pubSubManager = PubSubManager(config) + admin = PubSubAdmin(config) + + // Create the topic and subscription + admin.createTopic(topicId) + admin.createSubscription(topicId, subscriptionId) + + // Get publisher and subscriber + publisher = pubSubManager.getOrCreatePublisher(topicId) + subscriber = pubSubManager.getOrCreateSubscriber(topicId, subscriptionId) + } + override def afterAll(): Unit = { - factory.shutdown() + // Clean up Temporal resources + if (factory != null) { + factory.shutdown() + } + + // Clean up Pub/Sub resources + try { + publisher.shutdown() + subscriber.shutdown() + admin.close() + pubSubManager.shutdown() + + // Also shutdown the manager to free all resources + PubSubManager.shutdownAll() + } catch { + case e: Exception => println(s"Error during PubSub cleanup: ${e.getMessage}") + } } - it should "handle simple node with one level deep correctly" in { + it should "handle simple node with one level deep correctly and publish messages to Pub/Sub" in { // Trigger workflow and wait for it to complete workflowOperations.startNodeWorkflow(TestNodeUtils.getSimpleNode).get() - // Verify that all node workflows are started and finished successfully - for (dependentNode <- Array("dep1", "dep2", "main")) { + // Expected nodes + val expectedNodes = Array("dep1", "dep2", "main") + + // Verify that all dependent node workflows are started and finished successfully + for (dependentNode <- expectedNodes) { workflowOperations.getWorkflowStatus(s"node-execution-${dependentNode}") should be( WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_COMPLETED) } + + // Verify Pub/Sub messages + val messages = subscriber.pullMessages() + + // Verify we received the expected number of messages + messages.size should 
be(expectedNodes.length) + + // Verify each node has a message + val nodeNames = messages.map(_.getAttributes.getOrElse("nodeName", "")) + nodeNames should contain allElementsOf (expectedNodes) } - it should "handle complex node with multiple levels deep correctly" in { + it should "handle complex node with multiple levels deep correctly and publish messages to Pub/Sub" in { // Trigger workflow and wait for it to complete workflowOperations.startNodeWorkflow(TestNodeUtils.getComplexNode).get() + // Expected nodes + val expectedNodes = Array("StagingQuery1", "StagingQuery2", "GroupBy1", "GroupBy2", "Join", "Derivation") + // Verify that all dependent node workflows are started and finished successfully - // Activity for Derivation node should trigger all downstream node workflows - for (dependentNode <- Array("StagingQuery1", "StagingQuery2", "GroupBy1", "GroupBy2", "Join", "Derivation")) { + for (dependentNode <- expectedNodes) { workflowOperations.getWorkflowStatus(s"node-execution-${dependentNode}") should be( WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_COMPLETED) } + + // Verify Pub/Sub messages + val messages = subscriber.pullMessages() + + // Verify we received the expected number of messages + messages.size should be(expectedNodes.length) + + // Verify each node has a message + val nodeNames = messages.map(_.getAttributes.getOrElse("nodeName", "")) + nodeNames should contain allElementsOf (expectedNodes) } } diff --git a/tools/build_rules/dependencies/maven_repository.bzl b/tools/build_rules/dependencies/maven_repository.bzl index 14b0a34b6b..c964a14f1f 100644 --- a/tools/build_rules/dependencies/maven_repository.bzl +++ b/tools/build_rules/dependencies/maven_repository.bzl @@ -156,6 +156,10 @@ maven_repository = repository( "com.google.cloud:google-cloud-bigtable-emulator:0.178.0", "com.google.cloud.hosted.kafka:managed-kafka-auth-login-handler:1.0.3", "com.google.cloud:google-cloud-spanner:6.86.0", + "com.google.api:api-common:2.46.1", + 
"com.google.api:gax:2.60.0", + "com.google.api:gax-grpc:2.60.0", + "com.google.api.grpc:proto-google-cloud-pubsub-v1:1.120.0", # Flink "org.apache.flink:flink-metrics-dropwizard:1.17.0",