diff --git a/.vscode/settings.json b/.vscode/settings.json index eddbb8c0fb..1c17f1d0a0 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -22,6 +22,14 @@ "rust/experimental/query_abstraction/Cargo.toml", "rust/experimental/query_engine/Cargo.toml" ], + + // Enable specific features for rust-analyzer + "rust-analyzer.cargo.features": [ + "experimental-exporters", + "geneva-exporter", + "azure-monitor-exporter", + "azure-identity-auth-extension" + ], // Exclude Rust build artifacts from file watching and search // to improve performance and reduce noise in search results. diff --git a/rust/otap-dataflow/Cargo.toml b/rust/otap-dataflow/Cargo.toml index c6c0d44579..966eabe2e8 100644 --- a/rust/otap-dataflow/Cargo.toml +++ b/rust/otap-dataflow/Cargo.toml @@ -209,6 +209,7 @@ experimental-tls = ["otap-df-otap/experimental-tls", "dep:rustls"] experimental-exporters = ["otap-df-otap/experimental-exporters"] geneva-exporter = ["otap-df-otap/geneva-exporter"] azure-monitor-exporter = ["otap-df-otap/azure-monitor-exporter"] +azure-identity-auth-extension = ["otap-df-otap/azure-identity-auth-extension"] # Experimental processors (opt-in) experimental-processors = ["otap-df-otap/experimental-processors"] condense-attributes-processor = ["otap-df-otap/condense-attributes-processor"] diff --git a/rust/otap-dataflow/crates/config/src/node.rs b/rust/otap-dataflow/crates/config/src/node.rs index 80a06d4de8..e0af67ba52 100644 --- a/rust/otap-dataflow/crates/config/src/node.rs +++ b/rust/otap-dataflow/crates/config/src/node.rs @@ -104,6 +104,8 @@ pub enum NodeKind { Processor, /// A sink of signals Exporter, + /// An extension providing auxiliary services (no signal processing) + Extension, // ToDo(LQ) : Add more node kinds as needed. // A connector between two pipelines @@ -118,6 +120,7 @@ impl From for Cow<'static, str> { NodeKind::Receiver => "receiver".into(), NodeKind::Processor => "processor".into(), NodeKind::Exporter => "exporter".into(), + NodeKind::Extension => "extension".into(), NodeKind::ProcessorChain => "processor_chain".into(), } } diff --git a/rust/otap-dataflow/crates/config/src/urn.rs b/rust/otap-dataflow/crates/config/src/urn.rs index 24b055e1c2..258089531a 100644 --- a/rust/otap-dataflow/crates/config/src/urn.rs +++ b/rust/otap-dataflow/crates/config/src/urn.rs @@ -98,6 +98,7 @@ pub fn validate_plugin_urn(raw: &str, expected_kind: NodeKind) -> Result<(), Err NodeKind::Receiver => "receiver", NodeKind::Processor | NodeKind::ProcessorChain => "processor", NodeKind::Exporter => "exporter", + NodeKind::Extension => "extension", }; if last != expected_suffix { return Err(Error::InvalidUserConfig { diff --git a/rust/otap-dataflow/crates/engine-macros/src/lib.rs b/rust/otap-dataflow/crates/engine-macros/src/lib.rs index dcad1a7ef3..e80be0c111 100644 --- a/rust/otap-dataflow/crates/engine-macros/src/lib.rs +++ b/rust/otap-dataflow/crates/engine-macros/src/lib.rs @@ -58,7 +58,7 @@ impl Parse for PipelineFactoryArgs { /// The individual factory types are imported internally by the macro. /// /// This generates: -/// - Distributed slices for receiver, processor, and exporter factories (prefixed) +/// - Distributed slices for receiver, processor, exporter, and extension factories (prefixed) /// - Proper initialization of the FACTORY_REGISTRY with lazy loading /// - Helper functions to access factory maps (prefixed) #[proc_macro_attribute] @@ -75,6 +75,7 @@ pub fn pipeline_factory(args: TokenStream, input: TokenStream) -> TokenStream { let receiver_factories_name = quote::format_ident!("{}_RECEIVER_FACTORIES", prefix); let processor_factories_name = quote::format_ident!("{}_PROCESSOR_FACTORIES", prefix); let exporter_factories_name = quote::format_ident!("{}_EXPORTER_FACTORIES", prefix); + let extension_factories_name = quote::format_ident!("{}_EXTENSION_FACTORIES", prefix); let get_receiver_factory_map_name = quote::format_ident!( "get_{}_receiver_factory_map", prefix.to_string().to_lowercase() @@ -87,6 +88,10 @@ pub fn pipeline_factory(args: TokenStream, input: TokenStream) -> TokenStream { "get_{}_exporter_factory_map", prefix.to_string().to_lowercase() ); + let get_extension_factory_map_name = quote::format_ident!( + "get_{}_extension_factory_map", + prefix.to_string().to_lowercase() + ); let output = quote! { /// A slice of receiver factories. @@ -101,14 +106,19 @@ pub fn pipeline_factory(args: TokenStream, input: TokenStream) -> TokenStream { #[::otap_df_engine::distributed_slice] pub static #exporter_factories_name: [::otap_df_engine::ExporterFactory<#pdata_type>] = [..]; + /// A slice of extension factories. + #[::otap_df_engine::distributed_slice] + pub static #extension_factories_name: [::otap_df_engine::ExtensionFactory<#pdata_type>] = [..]; + /// The factory registry instance. #registry_vis static #registry_name: std::sync::LazyLock> = std::sync::LazyLock::new(|| { // Reference build_registry to avoid unused import warning, even though we don't call it let _ = build_factory::<#pdata_type>; - PipelineFactory::new( + PipelineFactory::with_extensions( &#receiver_factories_name, &#processor_factories_name, &#exporter_factories_name, + &#extension_factories_name, ) }); @@ -126,6 +136,11 @@ pub fn pipeline_factory(args: TokenStream, input: TokenStream) -> TokenStream { pub fn #get_exporter_factory_map_name() -> &'static std::collections::HashMap<&'static str, ::otap_df_engine::ExporterFactory<#pdata_type>> { #registry_name.get_exporter_factory_map() } + + /// Gets the extension factory map, initializing it if necessary. + pub fn #get_extension_factory_map_name() -> &'static std::collections::HashMap<&'static str, ::otap_df_engine::ExtensionFactory<#pdata_type>> { + #registry_name.get_extension_factory_map() + } }; output.into() diff --git a/rust/otap-dataflow/crates/engine/src/config.rs b/rust/otap-dataflow/crates/engine/src/config.rs index b204c36a82..985d39060f 100644 --- a/rust/otap-dataflow/crates/engine/src/config.rs +++ b/rust/otap-dataflow/crates/engine/src/config.rs @@ -79,6 +79,18 @@ pub struct ExporterConfig { pub input_pdata_channel: PdataChannelConfig, } +/// Generic configuration for an extension. +/// +/// Extensions are special components that don't process pdata, so they only have +/// a control channel configuration. +#[derive(Clone, Debug)] +pub struct ExtensionConfig { + /// Name of the extension. + pub name: NodeId, + /// Configuration for control channel. + pub control_channel: ControlChannelConfig, +} + impl ReceiverConfig { /// Creates a new receiver configuration with the given name and default channel capacity. pub fn new(name: T) -> Self @@ -137,3 +149,19 @@ impl ExporterConfig { } } } + +impl ExtensionConfig { + /// Creates a new extension configuration with the given name and default channel capacity. + #[must_use] + pub fn new(name: T) -> Self + where + T: Into, + { + ExtensionConfig { + name: name.into(), + control_channel: ControlChannelConfig { + capacity: DEFAULT_CONTROL_CHANNEL_CAPACITY, + }, + } + } +} diff --git a/rust/otap-dataflow/crates/engine/src/effect_handler.rs b/rust/otap-dataflow/crates/engine/src/effect_handler.rs index 5c58a9597b..6db10f0115 100644 --- a/rust/otap-dataflow/crates/engine/src/effect_handler.rs +++ b/rust/otap-dataflow/crates/engine/src/effect_handler.rs @@ -5,6 +5,8 @@ use crate::control::{AckMsg, NackMsg, PipelineControlMsg, PipelineCtrlMsgSender}; use crate::error::Error; +use crate::extensions::ExtensionTrait; +use crate::extensions::registry::{ExtensionError, ExtensionRegistry}; use crate::node::NodeId; use otap_df_channel::error::SendError; use otap_df_telemetry::error::Error as TelemetryError; @@ -25,6 +27,8 @@ pub(crate) struct EffectHandlerCore { #[allow(dead_code)] // Will be used in the future. ToDo report metrics from channel and messages. pub(crate) metrics_reporter: MetricsReporter, + /// Registry of extension traits for capability lookup. + pub(crate) extension_registry: Option, } impl EffectHandlerCore { @@ -34,6 +38,7 @@ impl EffectHandlerCore { node_id, pipeline_ctrl_msg_sender: None, metrics_reporter, + extension_registry: None, } } @@ -45,12 +50,50 @@ impl EffectHandlerCore { self.pipeline_ctrl_msg_sender = Some(pipeline_ctrl_msg_sender); } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: ExtensionRegistry) { + self.extension_registry = Some(registry); + } + /// Returns the id of the node associated with this effect handler. #[must_use] pub(crate) fn node_id(&self) -> NodeId { self.node_id.clone() } + /// Gets an extension trait implementation by extension name. + /// + /// This allows components to look up capabilities provided by extensions. + /// + /// # Type Parameters + /// + /// * `T` - The trait type (e.g., `dyn BearerTokenProvider`). Must implement `ExtensionTrait`. + /// + /// # Errors + /// + /// Returns `ExtensionError::NotFound` if no extension with that name exists or if the + /// extension registry has not been set. + /// Returns `ExtensionError::TraitNotImplemented` if the extension doesn't implement the trait. + /// + /// # Example + /// + /// ```ignore + /// let token_provider: &dyn BearerTokenProvider = effect_handler + /// .get_extension::("azure_auth")?; + /// let token = token_provider.get_token(); + /// ``` + pub(crate) fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + match &self.extension_registry { + Some(registry) => registry.get_trait::(name), + None => Err(ExtensionError::NotFound { + name: name.to_string(), + }), + } + } + /// Print an info message to stdout. /// /// This method provides a standardized way for all nodes in the pipeline diff --git a/rust/otap-dataflow/crates/engine/src/error.rs b/rust/otap-dataflow/crates/engine/src/error.rs index 6810f2a8db..9b529d8a4b 100644 --- a/rust/otap-dataflow/crates/engine/src/error.rs +++ b/rust/otap-dataflow/crates/engine/src/error.rs @@ -95,6 +95,28 @@ impl fmt::Display for ProcessorErrorKind { } } +/// High-level classification for extension failures to aid troubleshooting. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum ExtensionErrorKind { + /// Errors caused by invalid or incomplete configuration detected at runtime. + Configuration, + /// Errors raised while shutting down an extension. + Shutdown, + /// Catch-all for extension failures that do not fit other categories. + Other, +} + +impl fmt::Display for ExtensionErrorKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let label = match self { + ExtensionErrorKind::Configuration => "configuration", + ExtensionErrorKind::Shutdown => "shutdown", + ExtensionErrorKind::Other => "other", + }; + write!(f, "{label}") + } +} + /// Formats the source chain of an error into a single display string. #[must_use] pub fn format_error_sources(error: &(dyn std::error::Error + 'static)) -> String { @@ -323,6 +345,36 @@ pub enum Error { plugin_urn: NodeUrn, }, + /// The specified extension already exists in the pipeline. + #[error("The extension `{extension}` already exists")] + ExtensionAlreadyExists { + /// The name of the extension that already exists. + extension: NodeId, + }, + + /// A wrapper for the extension errors. + #[error("An extension error occurred in node {extension} ({kind}): {error}{source_detail}")] + ExtensionError { + /// The name of the extension that encountered the error. + extension: NodeId, + + /// High-level classification for the extension failure. + kind: ExtensionErrorKind, + + /// The error that occurred. + error: String, + + /// Pre-formatted representation of the source chain used when rendering the error. + source_detail: String, + }, + + /// Unknown extension plugin. + #[error("Unknown extension plugin `{plugin_urn}`")] + UnknownExtension { + /// The name of the unknown extension plugin. + plugin_urn: NodeUrn, + }, + /// Unknown node. #[error("Unknown node `{node}`")] UnknownNode { @@ -424,6 +476,8 @@ impl Error { Error::ConfigError(_) => "ConfigError", Error::ExporterAlreadyExists { .. } => "ExporterAlreadyExists", Error::ExporterError { .. } => "ExporterError", + Error::ExtensionAlreadyExists { .. } => "ExtensionAlreadyExists", + Error::ExtensionError { .. } => "ExtensionError", Error::InternalError { .. } => "InternalError", Error::InvalidHyperEdge { .. } => "InvalidHyperEdge", Error::IoError { .. } => "IoError", @@ -443,6 +497,7 @@ impl Error { Error::SpmcSharedNotSupported { .. } => "SpmcSharedNotSupported", Error::TooManyNodes {} => "TooManyNodes", Error::UnknownExporter { .. } => "UnknownExporter", + Error::UnknownExtension { .. } => "UnknownExtension", Error::UnknownNode { .. } => "UnknownNode", Error::UnknownOutPort { .. } => "UnknownPort", Error::UnknownProcessor { .. } => "UnknownProcessor", diff --git a/rust/otap-dataflow/crates/engine/src/exporter.rs b/rust/otap-dataflow/crates/engine/src/exporter.rs index d5b4601c02..beb9c448d5 100644 --- a/rust/otap-dataflow/crates/engine/src/exporter.rs +++ b/rust/otap-dataflow/crates/engine/src/exporter.rs @@ -14,6 +14,7 @@ use crate::context::PipelineContext; use crate::control::{Controllable, NodeControlMsg, PipelineCtrlMsgSender}; use crate::entity_context::NodeTelemetryGuard; use crate::error::{Error, ExporterErrorKind}; +use crate::extensions::ExtensionRegistry; use crate::local::exporter as local; use crate::local::message::{LocalReceiver, LocalSender}; use crate::message; @@ -51,6 +52,8 @@ pub enum ExporterWrapper { pdata_receiver: Option>, /// Telemetry guard for node lifecycle cleanup. telemetry: Option, + /// Extension registry for looking up extension capabilities. + extension_registry: Option, }, /// An exporter with a `Send` implementation. Shared { @@ -70,6 +73,8 @@ pub enum ExporterWrapper { pdata_receiver: Option>, /// Telemetry guard for node lifecycle cleanup. telemetry: Option, + /// Extension registry for looking up extension capabilities. + extension_registry: Option, }, } @@ -110,6 +115,7 @@ impl ExporterWrapper { control_receiver: LocalReceiver::mpsc(control_receiver), pdata_receiver: None, // This will be set later telemetry: None, + extension_registry: None, } } @@ -136,6 +142,19 @@ impl ExporterWrapper { control_receiver: SharedReceiver::mpsc(control_receiver), pdata_receiver: None, // This will be set later telemetry: None, + extension_registry: None, + } + } + + /// Sets the extension registry for this exporter to use at runtime. + pub fn set_extension_registry(&mut self, registry: ExtensionRegistry) { + match self { + ExporterWrapper::Local { + extension_registry, .. + } => *extension_registry = Some(registry), + ExporterWrapper::Shared { + extension_registry, .. + } => *extension_registry = Some(registry), } } @@ -149,6 +168,7 @@ impl ExporterWrapper { control_sender, control_receiver, pdata_receiver, + extension_registry, .. } => ExporterWrapper::Local { node_id, @@ -159,6 +179,7 @@ impl ExporterWrapper { control_receiver, pdata_receiver, telemetry: Some(guard), + extension_registry, }, ExporterWrapper::Shared { node_id, @@ -168,6 +189,7 @@ impl ExporterWrapper { control_sender, control_receiver, pdata_receiver, + extension_registry, .. } => ExporterWrapper::Shared { node_id, @@ -178,6 +200,7 @@ impl ExporterWrapper { control_receiver, pdata_receiver, telemetry: Some(guard), + extension_registry, }, } } @@ -205,7 +228,7 @@ impl ExporterWrapper { exporter, pdata_receiver, telemetry, - .. + extension_registry, } => { let (control_sender, control_receiver) = wrap_control_channel_metrics::( @@ -227,6 +250,7 @@ impl ExporterWrapper { control_receiver, pdata_receiver, telemetry, + extension_registry, } } ExporterWrapper::Shared { @@ -238,7 +262,7 @@ impl ExporterWrapper { exporter, pdata_receiver, telemetry, - .. + extension_registry, } => { let (control_sender, control_receiver) = wrap_control_channel_metrics::( @@ -260,6 +284,7 @@ impl ExporterWrapper { control_receiver, pdata_receiver, telemetry, + extension_registry, } } } @@ -278,11 +303,15 @@ impl ExporterWrapper { exporter, control_receiver, pdata_receiver, + extension_registry, .. }, metrics_reporter, ) => { let mut effect_handler = local::EffectHandler::new(node_id, metrics_reporter); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } let pdata_rx = pdata_receiver.ok_or_else(|| Error::ExporterError { exporter: effect_handler.exporter_id(), kind: ExporterErrorKind::Configuration, @@ -302,11 +331,15 @@ impl ExporterWrapper { exporter, control_receiver, pdata_receiver, + extension_registry, .. }, metrics_reporter, ) => { let mut effect_handler = shared::EffectHandler::new(node_id, metrics_reporter); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } let pdata_rx = pdata_receiver.ok_or_else(|| Error::ExporterError { exporter: effect_handler.exporter_id(), kind: ExporterErrorKind::Configuration, diff --git a/rust/otap-dataflow/crates/engine/src/extension.rs b/rust/otap-dataflow/crates/engine/src/extension.rs new file mode 100644 index 0000000000..cdbf3057a6 --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/extension.rs @@ -0,0 +1,579 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Extension wrapper used to provide a unified interface to the pipeline engine that abstracts over +//! the fact that extension implementations may be `!Send` or `Send`. +//! +//! For more details on the `!Send` implementation of an extension, see [`local::Extension`]. +//! See [`shared::Extension`] for the Send implementation. + +use crate::channel_metrics::ChannelMetricsRegistry; +use crate::channel_mode::{LocalMode, SharedMode, wrap_control_channel_metrics}; +use crate::config::ExtensionConfig; +use crate::context::PipelineContext; +use crate::control::{Controllable, NodeControlMsg, PipelineCtrlMsgSender}; +use crate::entity_context::NodeTelemetryGuard; +use crate::error::Error; +use crate::extensions::ExtensionRegistry; +use crate::extensions::registry::ExtensionTraits; +use crate::local::extension as local; +use crate::local::message::{LocalReceiver, LocalSender}; +use crate::message; +use crate::message::{Receiver, Sender}; +use crate::node::{Node, NodeId}; +use crate::shared::extension as shared; +use crate::shared::message::{SharedReceiver, SharedSender}; +use crate::terminal_state::TerminalState; +use otap_df_channel::error::SendError; +use otap_df_channel::mpsc; +use otap_df_config::node::NodeUserConfig; +use otap_df_telemetry::reporter::MetricsReporter; +use std::sync::Arc; + +/// A wrapper for the extension that allows for both `Send` and `!Send` effect handlers. +/// +/// Note: This is useful for creating a single interface for the extension regardless of their +/// 'sendability'. +pub enum ExtensionWrapper { + /// An extension with a `!Send` implementation. + Local { + /// Index identifier for the node. + node_id: NodeId, + /// The user configuration for the node, including its name and channel settings. + user_config: Arc, + /// The runtime configuration for the extension. + runtime_config: ExtensionConfig, + /// The extension instance. + extension: Box>, + /// Cast functions for this extension's trait implementations. + /// Taken during pipeline initialization to build the central registry. + extension_traits: Option, + /// A sender for control messages. + control_sender: LocalSender>, + /// A receiver for control messages. + control_receiver: LocalReceiver>, + /// Telemetry guard for node lifecycle cleanup. + telemetry: Option, + /// Extension registry for accessing extension traits. + extension_registry: Option, + }, + /// An extension with a `Send` implementation. + Shared { + /// Index identifier for the node. + node_id: NodeId, + /// The user configuration for the node, including its name and channel settings. + user_config: Arc, + /// The runtime configuration for the extension. + runtime_config: ExtensionConfig, + /// The extension instance. + extension: Box + Send>, + /// Cast functions for this extension's trait implementations. + /// Taken during pipeline initialization to build the central registry. + extension_traits: Option, + /// A sender for control messages. + control_sender: SharedSender>, + /// A receiver for control messages. + control_receiver: SharedReceiver>, + /// Telemetry guard for node lifecycle cleanup. + telemetry: Option, + /// Extension registry for accessing extension traits. + extension_registry: Option, + }, +} + +#[async_trait::async_trait(?Send)] +impl Controllable for ExtensionWrapper { + /// Returns the control message sender for the extension. + fn control_sender(&self) -> Sender> { + match self { + ExtensionWrapper::Local { control_sender, .. } => Sender::Local(control_sender.clone()), + ExtensionWrapper::Shared { control_sender, .. } => { + Sender::Shared(control_sender.clone()) + } + } + } +} + +impl ExtensionWrapper { + /// Creates a new local `ExtensionWrapper` with the given extension and configuration (!Send + /// implementation). + /// + /// # Arguments + /// + /// * `extension` - The extension instance + /// * `extension_casters` - The extension traits bundle for registry lookups + /// * `node_id` - The node identifier + /// * `user_config` - The user configuration + /// * `config` - The extension runtime configuration + pub fn local( + extension: E, + extension_traits: ExtensionTraits, + node_id: NodeId, + user_config: Arc, + config: &ExtensionConfig, + ) -> Self + where + E: local::Extension + 'static, + { + let (control_sender, control_receiver) = + mpsc::Channel::new(config.control_channel.capacity); + + ExtensionWrapper::Local { + node_id, + user_config, + runtime_config: config.clone(), + extension: Box::new(extension), + extension_traits: Some(extension_traits), + control_sender: LocalSender::mpsc(control_sender), + control_receiver: LocalReceiver::mpsc(control_receiver), + telemetry: None, + extension_registry: None, + } + } + + /// Creates a new shared `ExtensionWrapper` with the given extension and configuration (Send + /// implementation). + /// + /// # Arguments + /// + /// * `extension` - The extension instance + /// * `extension_casters` - The extension traits bundle for registry lookups + /// * `node_id` - The node identifier + /// * `user_config` - The user configuration + /// * `config` - The extension runtime configuration + pub fn shared( + extension: E, + extension_traits: ExtensionTraits, + node_id: NodeId, + user_config: Arc, + config: &ExtensionConfig, + ) -> Self + where + E: shared::Extension + Send + 'static, + { + let (control_sender, control_receiver) = + tokio::sync::mpsc::channel(config.control_channel.capacity); + + ExtensionWrapper::Shared { + node_id, + user_config, + runtime_config: config.clone(), + extension: Box::new(extension), + extension_traits: Some(extension_traits), + control_sender: SharedSender::mpsc(control_sender), + control_receiver: SharedReceiver::mpsc(control_receiver), + telemetry: None, + extension_registry: None, + } + } + + /// Sets the extension registry for this extension. + pub fn set_extension_registry(&mut self, registry: ExtensionRegistry) { + match self { + ExtensionWrapper::Local { + extension_registry, .. + } => *extension_registry = Some(registry), + ExtensionWrapper::Shared { + extension_registry, .. + } => *extension_registry = Some(registry), + } + } + + pub(crate) fn with_node_telemetry_guard(self, guard: NodeTelemetryGuard) -> Self { + match self { + ExtensionWrapper::Local { + node_id, + user_config, + runtime_config, + extension, + extension_traits, + control_sender, + control_receiver, + extension_registry, + .. + } => ExtensionWrapper::Local { + node_id, + user_config, + runtime_config, + extension, + extension_traits, + control_sender, + control_receiver, + telemetry: Some(guard), + extension_registry, + }, + ExtensionWrapper::Shared { + node_id, + user_config, + runtime_config, + extension, + extension_traits, + control_sender, + control_receiver, + extension_registry, + .. + } => ExtensionWrapper::Shared { + node_id, + user_config, + runtime_config, + extension, + extension_traits, + control_sender, + control_receiver, + telemetry: Some(guard), + extension_registry, + }, + } + } + + pub(crate) fn take_telemetry_guard(&mut self) -> Option { + match self { + ExtensionWrapper::Local { telemetry, .. } => telemetry.take(), + ExtensionWrapper::Shared { telemetry, .. } => telemetry.take(), + } + } + + pub(crate) fn with_control_channel_metrics( + self, + pipeline_ctx: &PipelineContext, + channel_metrics: &mut ChannelMetricsRegistry, + channel_metrics_enabled: bool, + ) -> Self { + match self { + ExtensionWrapper::Local { + node_id, + runtime_config, + control_sender, + control_receiver, + user_config, + extension, + extension_traits, + telemetry, + extension_registry, + } => { + let (control_sender, control_receiver) = + wrap_control_channel_metrics::( + &node_id, + pipeline_ctx, + channel_metrics, + channel_metrics_enabled, + runtime_config.control_channel.capacity as u64, + control_sender, + control_receiver, + ); + + ExtensionWrapper::Local { + node_id, + user_config, + runtime_config, + extension, + extension_traits, + control_sender, + control_receiver, + telemetry, + extension_registry, + } + } + ExtensionWrapper::Shared { + node_id, + runtime_config, + control_sender, + control_receiver, + user_config, + extension, + extension_traits, + telemetry, + extension_registry, + } => { + let (control_sender, control_receiver) = + wrap_control_channel_metrics::( + &node_id, + pipeline_ctx, + channel_metrics, + channel_metrics_enabled, + runtime_config.control_channel.capacity as u64, + control_sender, + control_receiver, + ); + + ExtensionWrapper::Shared { + node_id, + user_config, + runtime_config, + extension, + extension_traits, + control_sender, + control_receiver, + telemetry, + extension_registry, + } + } + } + } + + /// Starts the extension and begins its operation. + pub async fn start( + self, + pipeline_ctrl_msg_tx: PipelineCtrlMsgSender, + metrics_reporter: MetricsReporter, + ) -> Result { + match (self, metrics_reporter) { + ( + ExtensionWrapper::Local { + node_id, + extension, + control_receiver, + extension_registry, + .. + }, + metrics_reporter, + ) => { + let mut effect_handler = local::EffectHandler::new(node_id, metrics_reporter); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } + effect_handler + .core + .set_pipeline_ctrl_msg_sender(pipeline_ctrl_msg_tx); + // Extensions only receive control messages, no pdata + // Create a dummy pdata receiver that will never receive anything + let (_dummy_tx, dummy_rx) = mpsc::Channel::::new(1); + let message_channel = message::MessageChannel::new( + Receiver::Local(control_receiver), + Receiver::Local(LocalReceiver::mpsc(dummy_rx)), + ); + extension.start(message_channel, effect_handler).await + } + ( + ExtensionWrapper::Shared { + node_id, + extension, + control_receiver, + extension_registry, + .. + }, + metrics_reporter, + ) => { + let mut effect_handler = shared::EffectHandler::new(node_id, metrics_reporter); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } + effect_handler + .core + .set_pipeline_ctrl_msg_sender(pipeline_ctrl_msg_tx); + let message_channel = shared::MessageChannel::new(control_receiver); + extension.start(message_channel, effect_handler).await + } + } + } + + /// Takes the extension traits from this wrapper, leaving `None` in its place. + /// + /// This is called during pipeline initialization to collect all extension traits + /// into the central registry. + pub fn take_extension_traits(&mut self) -> Option { + match self { + ExtensionWrapper::Local { + extension_traits, .. + } => extension_traits.take(), + ExtensionWrapper::Shared { + extension_traits, .. + } => extension_traits.take(), + } + } +} + +#[async_trait::async_trait(?Send)] +impl Node for ExtensionWrapper { + fn is_shared(&self) -> bool { + match self { + ExtensionWrapper::Local { .. } => false, + ExtensionWrapper::Shared { .. } => true, + } + } + + fn node_id(&self) -> NodeId { + match self { + ExtensionWrapper::Local { node_id, .. } => node_id.clone(), + ExtensionWrapper::Shared { node_id, .. } => node_id.clone(), + } + } + + fn user_config(&self) -> Arc { + match self { + ExtensionWrapper::Local { + user_config: config, + .. + } => config.clone(), + ExtensionWrapper::Shared { + user_config: config, + .. + } => config.clone(), + } + } + + /// Sends a control message to the node. + async fn send_control_msg( + &self, + msg: NodeControlMsg, + ) -> Result<(), SendError>> { + match self { + ExtensionWrapper::Local { control_sender, .. } => control_sender.send(msg).await, + ExtensionWrapper::Shared { control_sender, .. } => control_sender.send(msg).await, + } + } +} + +#[cfg(test)] +mod tests { + use crate::config::ExtensionConfig; + use crate::control::NodeControlMsg; + use crate::extension::{Error, ExtensionWrapper}; + use crate::extensions::registry::ExtensionTraits; + use crate::local::extension as local; + use crate::message; + use crate::message::Message; + use crate::node::Node; + use crate::shared::extension as shared; + use crate::terminal_state::TerminalState; + use crate::testing::{CtrlMsgCounters, TestMsg, test_node}; + use async_trait::async_trait; + use otap_df_config::node::{NodeKind, NodeUserConfig}; + use serde_json::Value; + use std::sync::Arc; + + /// A test extension that counts message events. + pub struct TestExtension { + /// Counter for different message types + pub counter: CtrlMsgCounters, + } + + impl TestExtension { + /// Creates a new test extension with the given counter + pub fn new(counter: CtrlMsgCounters) -> Self { + TestExtension { counter } + } + } + + #[async_trait(?Send)] + impl local::Extension for TestExtension { + async fn start( + self: Box, + mut msg_chan: message::MessageChannel, + _effect_handler: local::EffectHandler, + ) -> Result { + // Loop until a Shutdown event is received. + loop { + match msg_chan.recv().await? { + Message::Control(NodeControlMsg::TimerTick { .. }) => { + self.counter.increment_timer_tick(); + } + Message::Control(NodeControlMsg::Config { .. }) => { + self.counter.increment_config(); + } + Message::Control(NodeControlMsg::Shutdown { .. }) => { + self.counter.increment_shutdown(); + break; + } + Message::Control(NodeControlMsg::CollectTelemetry { .. }) => { + // Ignore telemetry collection requests in tests + } + Message::Control(NodeControlMsg::Ack(_)) => {} + Message::Control(NodeControlMsg::Nack(_)) => {} + Message::Control(NodeControlMsg::DelayedData { .. }) => {} + Message::PData(_) => { + // Extensions don't process pdata + } + } + } + Ok(TerminalState::default()) + } + } + + #[test] + fn test_local_extension_wrapper_creation() { + let counter = CtrlMsgCounters::new(); + let extension = TestExtension::new(counter); + let node_id = test_node("test_extension"); + let user_config = Arc::new(NodeUserConfig::with_user_config( + NodeKind::Receiver, // Extension is not a config kind yet + "urn:test:extension".into(), + Value::Null, + )); + let config = ExtensionConfig::new("test_extension"); + + let wrapper = ExtensionWrapper::local( + extension, + ExtensionTraits::new(), + node_id, + user_config, + &config, + ); + + assert!(!wrapper.is_shared()); + } + + /// A shared test extension + pub struct SharedTestExtension { + pub counter: CtrlMsgCounters, + } + + impl SharedTestExtension { + pub fn new(counter: CtrlMsgCounters) -> Self { + SharedTestExtension { counter } + } + } + + #[async_trait] + impl shared::Extension for SharedTestExtension { + async fn start( + self: Box, + mut msg_chan: shared::MessageChannel, + _effect_handler: shared::EffectHandler, + ) -> Result { + loop { + match msg_chan.recv().await? { + Message::Control(NodeControlMsg::TimerTick { .. }) => { + self.counter.increment_timer_tick(); + } + Message::Control(NodeControlMsg::Config { .. }) => { + self.counter.increment_config(); + } + Message::Control(NodeControlMsg::Shutdown { .. }) => { + self.counter.increment_shutdown(); + break; + } + Message::Control(NodeControlMsg::CollectTelemetry { .. }) => {} + Message::Control(NodeControlMsg::Ack(_)) => {} + Message::Control(NodeControlMsg::Nack(_)) => {} + Message::Control(NodeControlMsg::DelayedData { .. }) => {} + Message::PData(_) => {} + } + } + Ok(TerminalState::default()) + } + } + + #[test] + fn test_shared_extension_wrapper_creation() { + let counter = CtrlMsgCounters::new(); + let extension = SharedTestExtension::new(counter); + let node_id = test_node("test_extension"); + let user_config = Arc::new(NodeUserConfig::with_user_config( + NodeKind::Receiver, // Extension is not a config kind yet + "urn:test:extension".into(), + Value::Null, + )); + let config = ExtensionConfig::new("test_extension"); + + let wrapper = ExtensionWrapper::shared( + extension, + ExtensionTraits::new(), + node_id, + user_config, + &config, + ); + + assert!(wrapper.is_shared()); + } +} diff --git a/rust/otap-dataflow/crates/engine/src/extensions/bearer_token_provider.rs b/rust/otap-dataflow/crates/engine/src/extensions/bearer_token_provider.rs new file mode 100644 index 0000000000..6c24e5a7ad --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/extensions/bearer_token_provider.rs @@ -0,0 +1,165 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Token provider extension trait. + +use async_trait::async_trait; +use std::borrow::Cow; + +/// Represents a secret value that should not be exposed in logs or debug output. +/// +/// The [`Debug`] implementation will not print the actual secret value. +#[derive(Clone, Eq)] +pub struct Secret(Cow<'static, str>); + +impl Secret { + /// Creates a new `Secret`. + pub fn new(value: T) -> Self + where + T: Into>, + { + Self(value.into()) + } + + /// Returns the secret value. + #[must_use] + pub fn secret(&self) -> &str { + &self.0 + } +} + +// Constant-time comparison to prevent timing attacks. +// Note: LLVM may optimize this in unexpected ways. +impl PartialEq for Secret { + fn eq(&self, other: &Self) -> bool { + let a = self.secret(); + let b = other.secret(); + + if a.len() != b.len() { + return false; + } + + a.bytes() + .zip(b.bytes()) + .fold(0, |acc, (a, b)| acc | (a ^ b)) + == 0 + } +} + +impl From for Secret { + fn from(value: String) -> Self { + Self::new(value) + } +} + +impl From<&'static str> for Secret { + fn from(value: &'static str) -> Self { + Self::new(value) + } +} + +impl std::fmt::Debug for Secret { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Secret") + } +} + +/// Represents a bearer token with its expiration time. +/// +/// The token value is wrapped in [`Secret`] to prevent accidental exposure +/// in logs or debug output. +#[derive(Debug, Clone)] +pub struct BearerToken { + /// The token value. + pub token: Secret, + + /// The expiration time as a UNIX timestamp (seconds since epoch). + pub expires_on: i64, +} + +impl BearerToken { + /// Creates a new bearer token. + pub fn new(token: T, expires_on: i64) -> Self + where + T: Into, + { + Self { + token: token.into(), + expires_on, + } + } +} + +/// A trait for components that can provide bearer authentication tokens. +/// +/// Extensions implementing this trait can be looked up by other components +/// (e.g., exporters) to obtain tokens for authentication. +/// +/// # Thread Safety +/// +/// - The returned future is `Send` for use with async runtimes like tokio +/// - The error type is `Send + Sync` for safe propagation across threads +/// +/// # Subscribing to Token Refresh Events +/// +/// Use [`subscribe_token_refresh`](BearerTokenProvider::subscribe_token_refresh) to receive notifications when +/// tokens are refreshed. This is useful for updating HTTP headers or other +/// authentication state without polling. +/// +/// # Implementing This Trait +/// +/// External crates can implement this trait on their extension types: +/// +/// ```ignore +/// use async_trait::async_trait; +/// use otap_df_engine::extensions::{BearerToken, BearerTokenProvider, Error}; +/// +/// struct MyAuthExtension { /* ... */ } +/// +/// #[async_trait] +/// impl BearerTokenProvider for MyAuthExtension { +/// async fn get_token(&self) -> Result { +/// // ... acquire token ... +/// Ok(BearerToken { token: "...".into(), expires_on: 0 }) +/// } +/// +/// fn subscribe_token_refresh(&self) -> tokio::sync::watch::Receiver> { +/// self.token_sender.subscribe() +/// } +/// } +/// ``` +#[async_trait] +pub trait BearerTokenProvider: Send { + /// Returns an authentication token. + /// + /// # Errors + /// + /// Returns an error if the token cannot be obtained. + async fn get_token(&self) -> Result; + + /// Subscribes to token refresh events. + /// + /// Returns a new receiver that will be notified whenever the token + /// is refreshed. Each call creates an independent subscription. + /// The receiver always contains the latest token value (or `None` + /// if no token has been acquired yet). + /// + /// # Example + /// + /// ```ignore + /// let auth = effect_handler.get_extension::("auth")?; + /// let mut token_rx = auth.subscribe_token_refresh(); + /// + /// loop { + /// tokio::select! { + /// _ = token_rx.changed() => { + /// if let Some(token) = token_rx.borrow().as_ref() { + /// // Update headers, etc. + /// } + /// } + /// // ... other branches + /// } + /// } + /// ``` + fn subscribe_token_refresh(&self) -> tokio::sync::watch::Receiver>; +} diff --git a/rust/otap-dataflow/crates/engine/src/extensions/mod.rs b/rust/otap-dataflow/crates/engine/src/extensions/mod.rs new file mode 100644 index 0000000000..b2bd28b82a --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/extensions/mod.rs @@ -0,0 +1,65 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Extension traits and registry for capability-based lookups. +//! +//! This module provides: +//! - [`ExtensionTraits`](registry::ExtensionTraits) - Cast functions for an extension's traits +//! - [`ExtensionRegistry`](registry::ExtensionRegistry) - A registry to look up extension traits by name +//! - Common extension traits like [`BearerTokenProvider`](bearer_token_provider::BearerTokenProvider) +//! +//! # Adding New Extension Traits +//! +//! New extension traits must be defined in this module. External crates can implement +//! existing extension traits on their types, but cannot define new extension trait types. +//! +//! This restriction is enforced at compile time via the sealed trait pattern. + +pub mod registry; + +// Re-export commonly used types +pub use registry::{ + CastFn, ExtensionError, ExtensionRegistry, ExtensionRegistryBuilder, ExtensionTraits, TraitId, + raw_to_trait_ref, trait_ref_to_raw, +}; + +/// Extension traits that components can implement to expose capabilities. +pub mod bearer_token_provider; + +// Private module for sealing - external crates cannot implement Sealed +mod private { + pub trait Sealed {} +} + +/// Marker trait for extension trait types that can be stored in [`ExtensionBundle`]. +/// +/// This trait is **sealed** - it can only be implemented for `dyn` extension traits +/// defined in this module. External crates cannot add new extension trait types, +/// but they CAN implement existing traits like [`BearerTokenProvider`] on their types. +/// +/// # How It Works +/// +/// - `ExtensionTrait` is implemented for `dyn BearerTokenProvider` (and other extension traits) +/// - External crates can `impl BearerTokenProvider for MyType` freely +/// - External crates CANNOT create new traits usable with `ExtensionBundle` +/// +/// This ensures the extension system only supports well-defined, documented capabilities. +/// +/// # Thread Safety +/// +/// Extension traits only require `Send`, not `Sync`. The caster-based registry +/// stores boxed instances and returns borrowed references, avoiding the Arc/Rc +/// requirement that would force `Sync` on trait objects. +pub trait ExtensionTrait: private::Sealed + Send {} + +// Implement ExtensionTrait for each extension trait's dyn type. +// This is the ONLY place where ExtensionTrait can be implemented. +impl private::Sealed for dyn BearerTokenProvider {} +impl ExtensionTrait for dyn BearerTokenProvider {} + +/// Error type for extension operations. +/// +/// Thread-safe error type compatible with any `thiserror`-derived error. +pub type Error = Box; + +pub use bearer_token_provider::{BearerToken, BearerTokenProvider, Secret}; diff --git a/rust/otap-dataflow/crates/engine/src/extensions/registry.rs b/rust/otap-dataflow/crates/engine/src/extensions/registry.rs new file mode 100644 index 0000000000..70ca91e78c --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/extensions/registry.rs @@ -0,0 +1,570 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Extension registry for storing and retrieving extension trait implementations by name. +//! +//! This registry uses a caster-based approach where extensions store a single boxed +//! instance and cast functions for each trait they implement. This avoids Arc/Rc +//! and the associated Send+Sync requirements on trait objects. +//! +//! # Example +//! +//! ```ignore +//! // An extension registers its capabilities using the macro: +//! let instance = AzureIdentityAuthExtension::new(...); +//! let casters = extension_traits!(AzureIdentityAuthExtension => BearerTokenProvider); +//! +//! // Pass to ExtensionWrapper which builds the registry entry: +//! ExtensionWrapper::local(instance, casters, node_id, config, ...); +//! +//! // A consumer retrieves a capability by trait (returns a reference): +//! let token_provider: &dyn BearerTokenProvider = registry +//! .get_trait::("azure_auth")?; +//! ``` + +// Allow unsafe code in this module for fat pointer transmutation. +// The safety invariants are documented and upheld by the implementation. +#![allow(unsafe_code)] + +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::sync::Arc; + +/// A cast function: downcasts `&dyn Any` → `&ConcreteType` → `&dyn Trait`, +/// then returns the fat pointer as `[usize; 2]` for type-erased storage. +/// Returns None if the downcast fails. +pub type CastFn = fn(&dyn Any) -> Option<[usize; 2]>; + +/// Reconstruct a `&dyn Trait` from a `[usize; 2]` fat pointer. +/// +/// # Safety +/// The caller must ensure `fat` was produced by `trait_ref_to_raw` with the +/// same `Trait` type, and that the underlying data is still alive. +#[inline] +pub unsafe fn raw_to_trait_ref<'a, T: ?Sized + 'a>(fat: [usize; 2]) -> &'a T { + // SAFETY: The caller guarantees fat was produced by trait_ref_to_raw with the same T. + unsafe { std::mem::transmute_copy(&fat) } +} + +/// Convert a `&dyn Trait` fat pointer into `[usize; 2]` for storage. +/// +/// # Safety +/// Relies on the standard Rust fat-pointer layout: `[data_ptr, vtable_ptr]`. +#[inline] +pub unsafe fn trait_ref_to_raw(r: &T) -> [usize; 2] { + // SAFETY: Fat pointer layout is stable for trait objects. + unsafe { std::mem::transmute_copy(&r) } +} + +/// Marker trait for TypeId lookup of trait types. +/// Used to get a stable TypeId for `dyn Trait` types. +pub trait TraitId {} + +/// Error when retrieving an extension trait. +#[derive(Debug)] +pub enum ExtensionError { + /// Extension not found by name. + NotFound { + /// The name of the extension that was not found. + name: String, + }, + /// Extension found but doesn't implement the requested trait. + TraitNotImplemented { + /// The name of the extension. + name: String, + /// The expected trait name. + expected: &'static str, + }, +} + +impl std::fmt::Display for ExtensionError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ExtensionError::NotFound { name } => { + write!(f, "extension '{}' not found", name) + } + ExtensionError::TraitNotImplemented { name, expected } => { + write!( + f, + "extension '{}' does not implement trait {}", + name, expected + ) + } + } + } +} + +impl std::error::Error for ExtensionError {} + +/// Cast functions for an extension's trait implementations. +/// +/// This is the return type of the [`extension_traits!`] macro. It contains +/// the mapping from trait TypeIds to cast functions that can convert +/// `&dyn Any` to `&dyn Trait`. +/// +/// # Example +/// +/// ```ignore +/// use otap_df_engine::extension_traits; +/// use otap_df_engine::extensions::BearerTokenProvider; +/// +/// struct MyAuthExtension { /* ... */ } +/// impl BearerTokenProvider for MyAuthExtension { /* ... */ } +/// +/// let casters = extension_traits!(MyAuthExtension => BearerTokenProvider); +/// ``` +#[derive(Default)] +pub struct ExtensionTraits { + casters: HashMap, +} + +impl ExtensionTraits { + /// Create a new empty casters collection. + #[must_use] + pub fn new() -> Self { + Self { + casters: HashMap::new(), + } + } + + /// Create from a raw HashMap (used by the macro). + #[must_use] + pub fn from_map(casters: HashMap) -> Self { + Self { casters } + } + + /// Returns the inner HashMap. + #[must_use] + pub fn into_inner(self) -> HashMap { + self.casters + } + + /// Check if a trait is registered. + #[must_use] + pub fn contains(&self) -> bool { + self.casters.contains_key(&TypeId::of::>()) + } + + /// Returns true if no traits are registered. + #[must_use] + pub fn is_empty(&self) -> bool { + self.casters.is_empty() + } + + /// Returns the number of registered traits. + #[must_use] + pub fn len(&self) -> usize { + self.casters.len() + } +} + +impl std::fmt::Debug for ExtensionTraits { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionTraits") + .field("trait_count", &self.casters.len()) + .finish() + } +} + +/// Macro to generate cast functions for an extension's trait implementations. +/// +/// This macro generates a mapping from trait TypeIds to cast functions that +/// can convert `&dyn Any` to `&dyn Trait`. No Arc/Rc cloning is involved. +/// +/// # Arguments +/// +/// * First: The concrete type name (needed for downcast) +/// * After `=>`: Comma-separated list of trait names this type implements +/// +/// Returns an [`ExtensionTraits`] that can be passed to `ExtensionWrapper::local()`. +/// +/// # Type Safety +/// +/// Only traits that implement [`crate::extensions::ExtensionTrait`] can be used +/// with this macro. This is enforced at compile time - attempting to use an +/// arbitrary trait will result in a compilation error. The macro also verifies +/// that the concrete type implements each specified trait. +/// +/// # Example +/// +/// ```ignore +/// use otap_df_engine::extension_traits; +/// use otap_df_engine::extensions::BearerTokenProvider; +/// +/// struct MyAuthExtension { /* ... */ } +/// impl BearerTokenProvider for MyAuthExtension { /* ... */ } +/// +/// let instance = MyAuthExtension { /* ... */ }; +/// let traits = extension_traits!(MyAuthExtension => BearerTokenProvider); +/// +/// ExtensionWrapper::local(instance, traits, node_id, user_config, config); +/// ``` +#[macro_export] +macro_rules! extension_traits { + ($concrete_ty:ty => $($trait:ident),* $(,)?) => {{ + #[allow(unused_mut)] + let mut casters: std::collections::HashMap< + std::any::TypeId, + $crate::extensions::registry::CastFn + > = std::collections::HashMap::new(); + $( + { + // Compile-time check: ensure the trait is a valid ExtensionTrait. + // This prevents using arbitrary traits with this macro. + const _: fn() = || { + fn assert_extension_trait() {} + assert_extension_trait::(); + }; + + // Inner fn is monomorphic — $concrete_ty is substituted by the macro, + // so there are no captures and this coerces to a fn pointer. + fn __cast(any: &dyn std::any::Any) -> Option<[usize; 2]> { + let concrete = any.downcast_ref::<$concrete_ty>()?; + let trait_ref: &dyn $trait = concrete; + // SAFETY: We're converting a valid trait reference to its raw representation + Some(unsafe { $crate::extensions::registry::trait_ref_to_raw(trait_ref) }) + } + let _ = casters.insert( + std::any::TypeId::of::>(), + __cast as $crate::extensions::registry::CastFn, + ); + } + )* + $crate::extensions::registry::ExtensionTraits::from_map(casters) + }}; +} + +/// Internal storage for an extension instance and its casters. +/// +/// This is used internally by the registry to store extensions. +/// Users should not create this directly - use [`extension_traits!`] macro +/// with `ExtensionWrapper::local()` or `::shared()`. +pub struct ExtensionEntry { + /// The single concrete instance, type-erased. + instance: Box, + /// One cast function per registered trait. + casters: HashMap, +} + +impl ExtensionEntry { + /// Create a new entry from an instance and casters. + pub fn new(instance: T, casters: ExtensionTraits) -> Self { + Self { + instance: Box::new(instance), + casters: casters.into_inner(), + } + } + + /// Get a trait reference from the entry. + #[must_use] + pub fn get(&self) -> Option<&T> { + let cast = self.casters.get(&TypeId::of::>())?; + let fat = cast(self.instance.as_ref())?; + // SAFETY: `fat` was produced by `trait_ref_to_raw::` from a valid + // `&T` derived from the boxed instance. The box is alive for `&self`. + Some(unsafe { raw_to_trait_ref(fat) }) + } + + /// Check if the entry contains a trait implementation. + #[must_use] + pub fn contains(&self) -> bool { + self.casters.contains_key(&TypeId::of::>()) + } + + /// Returns the number of trait implementations. + #[must_use] + pub fn len(&self) -> usize { + self.casters.len() + } + + /// Returns true if no traits are registered. + #[must_use] + pub fn is_empty(&self) -> bool { + self.casters.is_empty() + } +} + +// ExtensionEntry is Send because: +// - instance: Box is Send +// - casters: HashMap - TypeId is Send+Sync, fn pointers are Send+Sync +// +// ExtensionEntry is Sync because: +// - The entry is immutable after construction +// - get() only returns shared references +unsafe impl Sync for ExtensionEntry {} + +impl std::fmt::Debug for ExtensionEntry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionEntry") + .field("trait_count", &self.casters.len()) + .finish() + } +} + +/// Registry for extension trait implementations. +/// +/// Extensions register themselves here during creation so other components +/// can look them up by name and retrieve trait references. +/// +/// The registry wraps entries in an `Arc` so it can be cheaply cloned +/// (e.g., when cloning effect handlers). Callers receive borrowed +/// `&dyn Trait` references tied to the registry's lifetime. +#[derive(Default, Clone)] +pub struct ExtensionRegistry { + extensions: Arc>, +} + +impl ExtensionRegistry { + /// Create a new empty registry. + #[must_use] + pub fn new() -> Self { + Self { + extensions: Arc::new(HashMap::new()), + } + } + + /// Create a registry from a map of extension entries. + #[must_use] + pub fn from_map(extensions: HashMap) -> Self { + Self { + extensions: Arc::new(extensions), + } + } + + /// Get a trait reference by extension name. + /// + /// Returns a borrowed `&dyn Trait` tied to the registry's lifetime. + /// + /// # Type Parameters + /// + /// * `T` - The trait type (e.g., `dyn BearerTokenProvider`). + /// + /// # Errors + /// + /// Returns `ExtensionError::NotFound` if no extension with that name exists. + /// Returns `ExtensionError::TraitNotImplemented` if the extension doesn't implement the trait. + /// + /// # Example + /// + /// ```ignore + /// let token_provider: &dyn BearerTokenProvider = registry + /// .get_trait::("azure_auth")?; + /// ``` + pub fn get_trait(&self, name: &str) -> Result<&T, ExtensionError> { + let entry = self + .extensions + .get(name) + .ok_or_else(|| ExtensionError::NotFound { + name: name.to_string(), + })?; + + entry.get::().ok_or_else(|| ExtensionError::TraitNotImplemented { + name: name.to_string(), + expected: std::any::type_name::(), + }) + } + + /// Check if an extension exists by name. + #[must_use] + pub fn contains(&self, name: &str) -> bool { + self.extensions.contains_key(name) + } + + /// Returns the number of registered extensions. + #[must_use] + pub fn len(&self) -> usize { + self.extensions.len() + } + + /// Returns true if no extensions are registered. + #[must_use] + pub fn is_empty(&self) -> bool { + self.extensions.is_empty() + } + + /// Returns an iterator over extension names. + pub fn names(&self) -> impl Iterator { + self.extensions.keys() + } +} + +impl std::fmt::Debug for ExtensionRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionRegistry") + .field("extensions", &self.extensions.keys().collect::>()) + .finish() + } +} + +/// Builder for constructing an [`ExtensionRegistry`]. +/// +/// Use this to register extension entries before creating the immutable registry. +/// +/// # Example +/// +/// ```ignore +/// let mut builder = ExtensionRegistryBuilder::new(); +/// +/// let auth = AzureIdentityAuthExtension::new(...); +/// let casters = extension_traits!(AzureIdentityAuthExtension => BearerTokenProvider); +/// builder.register("azure_auth", auth, casters); +/// +/// let registry = builder.build(); +/// ``` +#[derive(Default)] +pub struct ExtensionRegistryBuilder { + /// The map of extension names to entries being built. + pub extensions: HashMap, +} + +impl ExtensionRegistryBuilder { + /// Create a new empty builder. + #[must_use] + pub fn new() -> Self { + Self { + extensions: HashMap::new(), + } + } + + /// Register an extension with a name, instance, and casters. + pub fn register( + &mut self, + name: String, + instance: T, + casters: ExtensionTraits, + ) { + let _ = self.extensions.insert(name, ExtensionEntry::new(instance, casters)); + } + + /// Build the immutable registry. + #[must_use] + pub fn build(self) -> ExtensionRegistry { + ExtensionRegistry::from_map(self.extensions) + } +} + +impl std::fmt::Debug for ExtensionRegistryBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionRegistryBuilder") + .field("extensions", &self.extensions.keys().collect::>()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::extensions::BearerToken; + use crate::extensions::BearerTokenProvider; + use tokio::sync::watch; + + struct TestTokenProvider { + token: String, + } + + #[async_trait::async_trait] + impl BearerTokenProvider for TestTokenProvider { + async fn get_token(&self) -> Result { + Ok(BearerToken::new(self.token.clone(), 0)) + } + + fn subscribe_token_refresh(&self) -> watch::Receiver> { + let (tx, rx) = watch::channel(None); + drop(tx); + rx + } + } + + #[test] + fn test_extension_casters() { + let casters = crate::extension_traits!(TestTokenProvider => BearerTokenProvider); + assert_eq!(casters.len(), 1); + assert!(casters.contains::()); + } + + #[test] + fn test_extension_entry() { + let instance = TestTokenProvider { + token: "test_token".to_string(), + }; + let casters = crate::extension_traits!(TestTokenProvider => BearerTokenProvider); + let entry = ExtensionEntry::new(instance, casters); + + assert_eq!(entry.len(), 1); + assert!(entry.contains::()); + + let token_provider: &dyn BearerTokenProvider = entry.get().unwrap(); + drop(token_provider); + } + + #[test] + fn test_registry_get_trait() { + let instance = TestTokenProvider { + token: "test_token".to_string(), + }; + let casters = crate::extension_traits!(TestTokenProvider => BearerTokenProvider); + let entry = ExtensionEntry::new(instance, casters); + + let mut map = HashMap::new(); + let _ = map.insert("test_ext".to_string(), entry); + + let registry = ExtensionRegistry::from_map(map); + + let result: Result<&dyn BearerTokenProvider, _> = registry.get_trait("test_ext"); + assert!(result.is_ok()); + + let not_found: Result<&dyn BearerTokenProvider, _> = registry.get_trait("missing"); + assert!(matches!(not_found, Err(ExtensionError::NotFound { .. }))); + } + + #[test] + fn test_registry_builder() { + let mut builder = ExtensionRegistryBuilder::new(); + assert!(builder.extensions.is_empty()); + + let instance = TestTokenProvider { + token: "builder_test".to_string(), + }; + let casters = crate::extension_traits!(TestTokenProvider => BearerTokenProvider); + + builder.register("my_extension".to_string(), instance, casters); + + let registry = builder.build(); + assert_eq!(registry.len(), 1); + assert!(registry.contains("my_extension")); + let _: &dyn BearerTokenProvider = registry.get_trait("my_extension").unwrap(); + } + + #[test] + fn test_extension_error_display() { + let not_found = ExtensionError::NotFound { + name: "missing_ext".to_string(), + }; + let display = format!("{}", not_found); + assert!(display.contains("missing_ext")); + assert!(display.contains("not found")); + + let not_impl = ExtensionError::TraitNotImplemented { + name: "my_ext".to_string(), + expected: "BearerTokenProvider", + }; + let display = format!("{}", not_impl); + assert!(display.contains("my_ext")); + assert!(display.contains("BearerTokenProvider")); + } + + #[test] + fn test_registry_debug() { + let instance = TestTokenProvider { + token: "test".to_string(), + }; + let casters = crate::extension_traits!(TestTokenProvider => BearerTokenProvider); + let entry = ExtensionEntry::new(instance, casters); + + let registry = + ExtensionRegistry::from_map(HashMap::from([("test_ext".to_string(), entry)])); + let debug_str = format!("{:?}", registry); + assert!(debug_str.contains("ExtensionRegistry")); + assert!(debug_str.contains("test_ext")); + } +} diff --git a/rust/otap-dataflow/crates/engine/src/lib.rs b/rust/otap-dataflow/crates/engine/src/lib.rs index 1e222e157d..111c6251f4 100644 --- a/rust/otap-dataflow/crates/engine/src/lib.rs +++ b/rust/otap-dataflow/crates/engine/src/lib.rs @@ -10,11 +10,12 @@ use crate::{ CHANNEL_MODE_LOCAL, CHANNEL_MODE_SHARED, CHANNEL_TYPE_MPMC, CHANNEL_TYPE_MPSC, ChannelMetricsRegistry, ChannelReceiverMetrics, ChannelSenderMetrics, }, - config::{ExporterConfig, ProcessorConfig, ReceiverConfig}, + config::{ExporterConfig, ExtensionConfig, ProcessorConfig, ReceiverConfig}, control::{AckMsg, CallData, NackMsg}, entity_context::{NodeTelemetryGuard, NodeTelemetryHandle, with_node_telemetry_handle}, error::{Error, TypedError}, exporter::ExporterWrapper, + extension::ExtensionWrapper, local::message::{LocalReceiver, LocalSender}, message::{Receiver, Sender}, node::{Node, NodeDefs, NodeId, NodeName, NodeType}, @@ -42,6 +43,8 @@ use std::{collections::HashMap, sync::OnceLock}; pub mod error; pub mod exporter; +pub mod extension; +pub mod extensions; pub mod message; pub mod processor; pub mod receiver; @@ -159,6 +162,35 @@ impl NamedFactory for ExporterFactory { } } +/// A factory for creating extensions. +pub struct ExtensionFactory { + /// The name of the extension. + pub name: &'static str, + /// A function that creates a new extension instance. + pub create: fn( + pipeline: PipelineContext, + node: NodeId, + node_config: Arc, + extension_config: &ExtensionConfig, + ) -> Result, otap_df_config::error::Error>, +} + +// Note: We don't use `#[derive(Clone)]` here to avoid forcing the `PData` type to implement `Clone`. +impl Clone for ExtensionFactory { + fn clone(&self) -> Self { + ExtensionFactory { + name: self.name, + create: self.create, + } + } +} + +impl NamedFactory for ExtensionFactory { + fn name(&self) -> &'static str { + self.name + } +} + /// Returns a map of factory names to factory instances. pub fn get_factory_map( factory_map: &'static OnceLock>, @@ -283,15 +315,17 @@ pub const fn build_factory() -> PipelineFactory { /// A pipeline factory. /// -/// This factory contains a registry of all the micro-factories for receivers, processors, and -/// exporters, as well as the logic for creating pipelines based on a given configuration. +/// This factory contains a registry of all the micro-factories for receivers, processors, +/// exporters, and extensions, as well as the logic for creating pipelines based on a given configuration. pub struct PipelineFactory { receiver_factory_map: OnceLock>>, processor_factory_map: OnceLock>>, exporter_factory_map: OnceLock>>, + extension_factory_map: OnceLock>>, receiver_factories: &'static [ReceiverFactory], processor_factories: &'static [ProcessorFactory], exporter_factories: &'static [ExporterFactory], + extension_factories: &'static [ExtensionFactory], } impl PipelineFactory { @@ -306,9 +340,31 @@ impl PipelineFactory { receiver_factory_map: OnceLock::new(), processor_factory_map: OnceLock::new(), exporter_factory_map: OnceLock::new(), + extension_factory_map: OnceLock::new(), + receiver_factories, + processor_factories, + exporter_factories, + extension_factories: &[], + } + } + + /// Creates a new factory registry with extension factories included. + #[must_use] + pub const fn with_extensions( + receiver_factories: &'static [ReceiverFactory], + processor_factories: &'static [ProcessorFactory], + exporter_factories: &'static [ExporterFactory], + extension_factories: &'static [ExtensionFactory], + ) -> Self { + Self { + receiver_factory_map: OnceLock::new(), + processor_factory_map: OnceLock::new(), + exporter_factory_map: OnceLock::new(), + extension_factory_map: OnceLock::new(), receiver_factories, processor_factories, exporter_factories, + extension_factories, } } @@ -342,6 +398,16 @@ impl PipelineFactory { }) } + /// Gets the extension factory map, initializing it if necessary. + pub fn get_extension_factory_map(&self) -> &HashMap<&'static str, ExtensionFactory> { + self.extension_factory_map.get_or_init(|| { + self.extension_factories + .iter() + .map(|f| (f.name(), f.clone())) + .collect::>>() + }) + } + /// Builds a runtime pipeline from the given pipeline configuration. /// /// Main phases: @@ -365,6 +431,7 @@ impl PipelineFactory { let mut receivers = Vec::new(); let mut processors = Vec::new(); let mut exporters = Vec::new(); + let mut extensions = Vec::new(); let mut build_state = BuildState::new(); let pipeline_group_id = pipeline_ctx.pipeline_group_id(); @@ -457,6 +524,29 @@ impl PipelineFactory { )?; exporters.push(wrapper); } + otap_df_config::node::NodeKind::Extension => { + let node_id = build_state.next_node_id( + name.clone(), + NodeType::Extension, + PipeNode::new(extensions.len()), + )?; + let node_id_for_create = node_id.clone(); + let wrapper = self.build_node_wrapper( + &mut build_state, + &base_ctx, + NodeType::Extension, + node_id, + channel_metrics_enabled, + || { + self.create_extension( + &base_ctx, + node_id_for_create, + node_config.clone(), + ) + }, + )?; + extensions.push(wrapper); + } otap_df_config::node::NodeKind::ProcessorChain => { // ToDo(LQ): Implement processor chain optimization to eliminate intermediary channels. return Err(Error::UnsupportedNodeKind { @@ -468,11 +558,38 @@ impl PipelineFactory { let edges = collect_hyper_edges_runtime(&receivers, &processors); + // Build extension registry from extension traits + // Note: Extension trait lookups through the registry are not yet implemented. + // The traits are collected but need an instance to work with. + // TODO: Refactor extension ownership so registry can provide trait references. + let registry_builder = extensions::ExtensionRegistryBuilder::new(); + for extension in &mut extensions { + let _name = extension.node_id().name.to_string(); + let _traits = extension.take_extension_traits(); + // registry_builder.register(name, ???, traits); + } + let extension_registry = registry_builder.build(); + + // Set extension registry on all wrappers + for receiver in &mut receivers { + receiver.set_extension_registry(extension_registry.clone()); + } + for processor in &mut processors { + processor.set_extension_registry(extension_registry.clone()); + } + for exporter in &mut exporters { + exporter.set_extension_registry(extension_registry.clone()); + } + for extension in &mut extensions { + extension.set_extension_registry(extension_registry.clone()); + } + // First pass: plan hyper-edge wiring to avoid multiple mutable borrows let buffer_size = NonZeroUsize::new(config.pipeline_settings().default_pdata_channel_size) .expect("default_pdata_channel_size must be non-zero"); let nodes = std::mem::take(&mut build_state.nodes); - let mut pipeline = RuntimePipeline::new(config, receivers, processors, exporters, nodes); + let mut pipeline = + RuntimePipeline::new(config, receivers, processors, exporters, extensions, nodes); let wirings = edges .into_iter() .map(|hyper_edge| { @@ -1167,6 +1284,55 @@ impl PipelineFactory { Ok(exporter) } + + /// Creates an extension node and adds it to the list of runtime nodes. + fn create_extension( + &self, + pipeline_ctx: &PipelineContext, + node_id: NodeId, + node_config: Arc, + ) -> Result, Error> { + let pipeline_group_id = pipeline_ctx.pipeline_group_id(); + let pipeline_id = pipeline_ctx.pipeline_id(); + let core_id = pipeline_ctx.core_id(); + let name = node_id.name.clone(); + + otel_debug!( + "extension.create.start", + pipeline_group_id = pipeline_group_id.as_ref(), + pipeline_id = pipeline_id.as_ref(), + core_id = core_id, + node_id = name.as_ref(), + ); + + let factory = self + .get_extension_factory_map() + .get(node_config.plugin_urn.as_ref()) + .ok_or_else(|| Error::UnknownExtension { + plugin_urn: node_config.plugin_urn.clone(), + })?; + let extension_config = ExtensionConfig::new(name.clone()); + let create = factory.create; + + let node_id_for_create = node_id.clone(); + let extension = create( + (*pipeline_ctx).clone(), + node_id_for_create, + node_config, + &extension_config, + ) + .map_err(|e| Error::ConfigError(Box::new(e)))?; + + otel_debug!( + "extension.create.complete", + pipeline_group_id = pipeline_group_id.as_ref(), + pipeline_id = pipeline_id.as_ref(), + core_id = core_id, + node_id = name.as_ref(), + ); + + Ok(extension) + } } trait TelemetryWrapped: Sized { @@ -1239,6 +1405,26 @@ impl TelemetryWrapped for ExporterWrapper { } } +impl TelemetryWrapped for ExtensionWrapper { + fn with_control_channel_metrics( + self, + pipeline_ctx: &PipelineContext, + channel_metrics: &mut ChannelMetricsRegistry, + channel_metrics_enabled: bool, + ) -> Self { + ExtensionWrapper::with_control_channel_metrics( + self, + pipeline_ctx, + channel_metrics, + channel_metrics_enabled, + ) + } + + fn with_node_telemetry_guard(self, guard: NodeTelemetryGuard) -> Self { + ExtensionWrapper::with_node_telemetry_guard(self, guard) + } +} + struct NodeRegistration { node_id: NodeId, node_type: NodeType, @@ -1282,6 +1468,7 @@ impl BuildState { NodeType::Receiver => Error::ReceiverAlreadyExists { receiver: node_id }, NodeType::Processor => Error::ProcessorAlreadyExists { processor: node_id }, NodeType::Exporter => Error::ExporterAlreadyExists { exporter: node_id }, + NodeType::Extension => Error::ExtensionAlreadyExists { extension: node_id }, }); } @@ -1315,7 +1502,10 @@ impl BuildState { let registration = self.registration(name)?; match registration.node_type { NodeType::Processor | NodeType::Exporter => Ok(registration.node_id.clone()), - NodeType::Receiver => Err(Error::UnknownNode { node: name.clone() }), + // Receivers and extensions don't receive pdata + NodeType::Receiver | NodeType::Extension => { + Err(Error::UnknownNode { node: name.clone() }) + } } } } diff --git a/rust/otap-dataflow/crates/engine/src/local/exporter.rs b/rust/otap-dataflow/crates/engine/src/local/exporter.rs index 91123430f8..68a162f8aa 100644 --- a/rust/otap-dataflow/crates/engine/src/local/exporter.rs +++ b/rust/otap-dataflow/crates/engine/src/local/exporter.rs @@ -36,6 +36,8 @@ use crate::control::{AckMsg, NackMsg}; use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; use crate::error::Error; +use crate::extensions::ExtensionTrait; +use crate::extensions::registry::ExtensionError; use crate::message::MessageChannel; use crate::node::NodeId; use crate::terminal_state::TerminalState; @@ -114,6 +116,39 @@ impl EffectHandler { self.core.node_id() } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Gets an extension trait implementation by extension name. + /// + /// This allows exporters to look up capabilities provided by extensions, + /// such as authentication tokens or credentials. + /// + /// # Type Parameters + /// + /// * `T` - The trait type (e.g., `dyn BearerTokenProvider`). Must implement `ExtensionTrait`. + /// + /// # Errors + /// + /// Returns `ExtensionError::NotFound` if no extension with that name exists. + /// Returns `ExtensionError::TraitNotImplemented` if the extension doesn't implement the trait. + /// + /// # Example + /// + /// ```ignore + /// let token_provider: &dyn BearerTokenProvider = effect_handler + /// .get_extension::("azure_auth")?; + /// let token = token_provider.get_token(); + /// ``` + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + /// Print an info message to stdout. /// /// This method provides a standardized way for exporters to output diff --git a/rust/otap-dataflow/crates/engine/src/local/extension.rs b/rust/otap-dataflow/crates/engine/src/local/extension.rs new file mode 100644 index 0000000000..a1987dda92 --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/local/extension.rs @@ -0,0 +1,177 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Trait and structures used to implement local extensions (!Send). +//! +//! An extension is a special component that doesn't process pipeline data (pdata). +//! Extensions provide auxiliary services to the pipeline, such as health checks, +//! service discovery, or configuration management. +//! +//! Unlike receivers, processors, and exporters, extensions do not participate in +//! the data flow - they only handle control messages and provide services. +//! +//! # Lifecycle +//! +//! 1. The extension is instantiated and configured +//! 2. The `start` method is called, which begins the extension's operation +//! 3. The extension processes internal control messages +//! 4. The extension shuts down when it receives a `Shutdown` control message or encounters a fatal +//! error +//! +//! # Thread Safety +//! +//! This implementation is designed to be used in a single-threaded environment. +//! The `Extension` trait does not require the `Send` bound, allowing for the use of non-thread-safe +//! types. + +use crate::control::{AckMsg, NackMsg}; +use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; +use crate::error::Error; +use crate::extensions::ExtensionTrait; +use crate::extensions::registry::ExtensionError; +use crate::message::MessageChannel; +use crate::node::NodeId; +use crate::terminal_state::TerminalState; +use async_trait::async_trait; +use otap_df_telemetry::error::Error as TelemetryError; +use otap_df_telemetry::metrics::{MetricSet, MetricSetHandler}; +use otap_df_telemetry::reporter::MetricsReporter; +use std::marker::PhantomData; +use std::time::Duration; + +/// A trait for extensions (!Send definition). +/// +/// Extensions are special components that don't process pipeline data. +/// They provide auxiliary services to the pipeline. +#[async_trait(?Send)] +pub trait Extension { + /// Starts the extension and begins its operation. + /// + /// The pipeline engine will call this function to start the extension in a separate task. + /// Extensions are assigned their own dedicated task at pipeline initialization. + /// + /// The extension is taken as `Box` so the method takes ownership of the extension once `start` is called. + /// This lets it move into an independent task, after which the pipeline can only + /// reach it through the control-message channel. + /// + /// Extensions process control messages only - they do not receive or send pipeline data. + /// + /// # Parameters + /// + /// - `msg_chan`: A channel to receive control messages only (no pdata). + /// - `effect_handler`: A handler to perform side effects. + /// + /// # Errors + /// + /// Returns an [`Error`] if an unrecoverable error occurs. + /// + /// # Cancellation Safety + /// + /// This method should be cancellation safe and clean up any resources when dropped. + async fn start( + self: Box, + msg_chan: MessageChannel, + effect_handler: EffectHandler, + ) -> Result; +} + +/// A `!Send` implementation of the EffectHandler for extensions. +#[derive(Clone)] +pub struct EffectHandler { + pub(crate) core: EffectHandlerCore, + _pd: PhantomData, +} + +impl EffectHandler { + /// Creates a new local (!Send) `EffectHandler` with the given extension node id and metrics + /// reporter. + #[must_use] + pub fn new(node_id: NodeId, metrics_reporter: MetricsReporter) -> Self { + EffectHandler { + core: EffectHandlerCore::new(node_id, metrics_reporter), + _pd: PhantomData, + } + } + + /// Returns the id of the extension associated with this handler. + #[must_use] + pub fn extension_id(&self) -> NodeId { + self.core.node_id() + } + + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Gets an extension trait implementation by extension name. + /// + /// This allows extensions to look up capabilities provided by other extensions. + /// + /// # Type Parameters + /// + /// * `T` - The trait type (e.g., `dyn BearerTokenProvider`). Must implement `ExtensionTrait`. + /// + /// # Errors + /// + /// Returns `ExtensionError::NotFound` if no extension with that name exists. + /// Returns `ExtensionError::TraitNotImplemented` if the extension doesn't implement the trait. + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + + /// Print an info message to stdout. + /// + /// This method provides a standardized way for extensions to output + /// informational messages without blocking the async runtime. + pub async fn info(&self, message: &str) { + self.core.info(message).await; + } + + /// Starts a cancellable periodic timer that emits TimerTick on the control channel. + /// Returns a handle that can be used to cancel the timer. + /// + /// Current limitation: Only one timer can be started by an extension at a time. + pub async fn start_periodic_timer( + &self, + duration: Duration, + ) -> Result, Error> { + self.core.start_periodic_timer(duration).await + } + + /// Starts a cancellable periodic telemetry timer that emits CollectTelemetry. + pub async fn start_periodic_telemetry( + &self, + duration: Duration, + ) -> Result, Error> { + self.core.start_periodic_telemetry(duration).await + } + + /// Send an Ack to a node of known-interest. + pub async fn route_ack(&self, ack: AckMsg, cxf: F) -> Result<(), Error> + where + F: FnOnce(AckMsg) -> Option<(usize, AckMsg)>, + { + self.core.route_ack(ack, cxf).await + } + + /// Send a Nack to a node of known-interest. + pub async fn route_nack(&self, nack: NackMsg, cxf: F) -> Result<(), Error> + where + F: FnOnce(NackMsg) -> Option<(usize, NackMsg)>, + { + self.core.route_nack(nack, cxf).await + } + + /// Reports metrics collected by the extension. + #[allow(dead_code)] // Will be used in the future. ToDo report metrics from channel and messages. + pub(crate) fn report_metrics( + &mut self, + metrics: &mut MetricSet, + ) -> Result<(), TelemetryError> { + self.core.report_metrics(metrics) + } +} diff --git a/rust/otap-dataflow/crates/engine/src/local/mod.rs b/rust/otap-dataflow/crates/engine/src/local/mod.rs index b4bb3e477d..999778b260 100644 --- a/rust/otap-dataflow/crates/engine/src/local/mod.rs +++ b/rust/otap-dataflow/crates/engine/src/local/mod.rs @@ -1,9 +1,10 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -//! Traits and structs defining the local (!Send) version of receivers, processors, and exporters. +//! Traits and structs defining the local (!Send) version of receivers, processors, exporters, and extensions. pub mod exporter; +pub mod extension; pub mod message; pub mod processor; pub mod receiver; diff --git a/rust/otap-dataflow/crates/engine/src/local/processor.rs b/rust/otap-dataflow/crates/engine/src/local/processor.rs index e652d5694b..bb3dc643df 100644 --- a/rust/otap-dataflow/crates/engine/src/local/processor.rs +++ b/rust/otap-dataflow/crates/engine/src/local/processor.rs @@ -35,6 +35,8 @@ use crate::control::{AckMsg, NackMsg}; use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; use crate::error::{Error, TypedError}; +use crate::extensions::ExtensionTrait; +use crate::extensions::registry::ExtensionError; use crate::message::{Message, Sender}; use crate::node::NodeId; use async_trait::async_trait; @@ -132,6 +134,30 @@ impl EffectHandler { self.core.node_id() } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Gets an extension trait implementation by extension name. + /// + /// This allows processors to look up capabilities provided by extensions. + /// + /// # Type Parameters + /// + /// * `T` - The trait type (e.g., `dyn BearerTokenProvider`). Must implement `ExtensionTrait`. + /// + /// # Errors + /// + /// Returns `ExtensionError::NotFound` if no extension with that name exists. + /// Returns `ExtensionError::TraitNotImplemented` if the extension doesn't implement the trait. + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + /// Returns the list of connected out ports for this processor. #[must_use] pub fn connected_ports(&self) -> Vec { diff --git a/rust/otap-dataflow/crates/engine/src/local/receiver.rs b/rust/otap-dataflow/crates/engine/src/local/receiver.rs index b898aa433c..dd6f10f683 100644 --- a/rust/otap-dataflow/crates/engine/src/local/receiver.rs +++ b/rust/otap-dataflow/crates/engine/src/local/receiver.rs @@ -35,6 +35,7 @@ use crate::control::{NodeControlMsg, PipelineCtrlMsgSender}; use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; use crate::error::{Error, TypedError}; +use crate::extensions::{ExtensionError, ExtensionTrait}; use crate::message::Sender; use crate::node::NodeId; use crate::terminal_state::TerminalState; @@ -168,6 +169,23 @@ impl EffectHandler { self.core.node_id() } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Returns an extension trait implementation by name. + /// + /// # Errors + /// + /// Returns an [`ExtensionError`] if the extension is not found or doesn't implement the trait. + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + /// Returns the list of connected out ports for this receiver. #[must_use] pub fn connected_ports(&self) -> Vec { diff --git a/rust/otap-dataflow/crates/engine/src/node.rs b/rust/otap-dataflow/crates/engine/src/node.rs index 443764f5b0..b81550b2f7 100644 --- a/rust/otap-dataflow/crates/engine/src/node.rs +++ b/rust/otap-dataflow/crates/engine/src/node.rs @@ -58,6 +58,8 @@ pub enum NodeType { Processor, /// Represents a node that exports data to an external destination. Exporter, + /// Represents an extension node that provides auxiliary services (no pdata processing). + Extension, } /// Trait for nodes that can send pdata to a specific port. diff --git a/rust/otap-dataflow/crates/engine/src/processor.rs b/rust/otap-dataflow/crates/engine/src/processor.rs index de6ec4b783..e824642c58 100644 --- a/rust/otap-dataflow/crates/engine/src/processor.rs +++ b/rust/otap-dataflow/crates/engine/src/processor.rs @@ -14,6 +14,7 @@ use crate::context::PipelineContext; use crate::control::{Controllable, NodeControlMsg, PipelineCtrlMsgSender}; use crate::entity_context::NodeTelemetryGuard; use crate::error::{Error, ProcessorErrorKind}; +use crate::extensions::ExtensionRegistry; use crate::local::message::{LocalReceiver, LocalSender}; use crate::local::processor as local; use crate::message::{Message, MessageChannel, Receiver, Sender}; @@ -57,6 +58,8 @@ pub enum ProcessorWrapper { pdata_receiver: Option>, /// Telemetry guard for node lifecycle cleanup. telemetry: Option, + /// Extension registry for accessing extension traits. + extension_registry: Option, }, /// A processor with a `Send` implementation. Shared { @@ -79,6 +82,8 @@ pub enum ProcessorWrapper { pdata_receiver: Option>, /// Telemetry guard for node lifecycle cleanup. telemetry: Option, + /// Extension registry for accessing extension traits. + extension_registry: Option, }, } @@ -133,6 +138,7 @@ impl ProcessorWrapper { pdata_senders: HashMap::new(), pdata_receiver: None, telemetry: None, + extension_registry: None, } } @@ -160,6 +166,19 @@ impl ProcessorWrapper { pdata_senders: HashMap::new(), pdata_receiver: None, telemetry: None, + extension_registry: None, + } + } + + /// Sets the extension registry for this processor. + pub fn set_extension_registry(&mut self, registry: ExtensionRegistry) { + match self { + ProcessorWrapper::Local { + extension_registry, .. + } => *extension_registry = Some(registry), + ProcessorWrapper::Shared { + extension_registry, .. + } => *extension_registry = Some(registry), } } @@ -174,6 +193,7 @@ impl ProcessorWrapper { control_receiver, pdata_senders, pdata_receiver, + extension_registry, .. } => ProcessorWrapper::Local { node_id, @@ -185,6 +205,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, telemetry: Some(guard), + extension_registry, }, ProcessorWrapper::Shared { node_id, @@ -195,6 +216,7 @@ impl ProcessorWrapper { control_receiver, pdata_senders, pdata_receiver, + extension_registry, .. } => ProcessorWrapper::Shared { node_id, @@ -206,6 +228,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, telemetry: Some(guard), + extension_registry, }, } } @@ -234,7 +257,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, telemetry, - .. + extension_registry, } => { let (control_sender, control_receiver) = wrap_control_channel_metrics::( @@ -257,6 +280,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, telemetry, + extension_registry, } } ProcessorWrapper::Shared { @@ -269,7 +293,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, telemetry, - .. + extension_registry, } => { let (control_sender, control_receiver) = wrap_control_channel_metrics::( @@ -292,6 +316,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, telemetry, + extension_registry, } } } @@ -311,6 +336,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, user_config, + extension_registry, .. } => { let message_channel = MessageChannel::new( @@ -323,12 +349,15 @@ impl ProcessorWrapper { })?, ); let default_port = user_config.default_out_port.clone(); - let effect_handler = local::EffectHandler::new( + let mut effect_handler = local::EffectHandler::new( node_id, pdata_senders, default_port, metrics_reporter, ); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } Ok(ProcessorWrapperRuntime::Local { processor, effect_handler, @@ -342,6 +371,7 @@ impl ProcessorWrapper { pdata_senders, pdata_receiver, user_config, + extension_registry, .. } => { let message_channel = MessageChannel::new( @@ -354,12 +384,15 @@ impl ProcessorWrapper { })?), ); let default_port = user_config.default_out_port.clone(); - let effect_handler = shared::EffectHandler::new( + let mut effect_handler = shared::EffectHandler::new( node_id, pdata_senders, default_port, metrics_reporter, ); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } Ok(ProcessorWrapperRuntime::Shared { processor, effect_handler, diff --git a/rust/otap-dataflow/crates/engine/src/receiver.rs b/rust/otap-dataflow/crates/engine/src/receiver.rs index f11c52552a..d4ecc16208 100644 --- a/rust/otap-dataflow/crates/engine/src/receiver.rs +++ b/rust/otap-dataflow/crates/engine/src/receiver.rs @@ -14,6 +14,7 @@ use crate::context::PipelineContext; use crate::control::{Controllable, NodeControlMsg, PipelineCtrlMsgSender}; use crate::entity_context::NodeTelemetryGuard; use crate::error::{Error, ReceiverErrorKind}; +use crate::extensions::ExtensionRegistry; use crate::local::message::{LocalReceiver, LocalSender}; use crate::local::receiver as local; use crate::message::{Receiver, Sender}; @@ -56,6 +57,8 @@ pub enum ReceiverWrapper { pdata_receiver: Option>, /// Telemetry guard for node lifecycle cleanup. telemetry: Option, + /// Extension registry for accessing extension traits. + extension_registry: Option, }, /// A receiver with a `Send` implementation. Shared { @@ -78,6 +81,8 @@ pub enum ReceiverWrapper { pdata_receiver: Option>, /// Telemetry guard for node lifecycle cleanup. telemetry: Option, + /// Extension registry for accessing extension traits. + extension_registry: Option, }, } @@ -118,6 +123,7 @@ impl ReceiverWrapper { pdata_senders: HashMap::new(), pdata_receiver: None, telemetry: None, + extension_registry: None, } } @@ -144,6 +150,19 @@ impl ReceiverWrapper { pdata_senders: HashMap::new(), pdata_receiver: None, telemetry: None, + extension_registry: None, + } + } + + /// Sets the extension registry for this receiver. + pub fn set_extension_registry(&mut self, registry: ExtensionRegistry) { + match self { + ReceiverWrapper::Local { + extension_registry, .. + } => *extension_registry = Some(registry), + ReceiverWrapper::Shared { + extension_registry, .. + } => *extension_registry = Some(registry), } } @@ -158,6 +177,7 @@ impl ReceiverWrapper { control_receiver, pdata_senders, pdata_receiver, + extension_registry, .. } => ReceiverWrapper::Local { node_id, @@ -169,6 +189,7 @@ impl ReceiverWrapper { pdata_senders, pdata_receiver, telemetry: Some(guard), + extension_registry, }, ReceiverWrapper::Shared { node_id, @@ -179,6 +200,7 @@ impl ReceiverWrapper { control_receiver, pdata_senders, pdata_receiver, + extension_registry, .. } => ReceiverWrapper::Shared { node_id, @@ -190,6 +212,7 @@ impl ReceiverWrapper { pdata_senders, pdata_receiver, telemetry: Some(guard), + extension_registry, }, } } @@ -218,7 +241,7 @@ impl ReceiverWrapper { pdata_senders, pdata_receiver, telemetry, - .. + extension_registry, } => { let (control_sender, control_receiver) = wrap_control_channel_metrics::( @@ -241,6 +264,7 @@ impl ReceiverWrapper { pdata_senders, pdata_receiver, telemetry, + extension_registry, } } ReceiverWrapper::Shared { @@ -253,7 +277,7 @@ impl ReceiverWrapper { pdata_senders, pdata_receiver, telemetry, - .. + extension_registry, } => { let (control_sender, control_receiver) = wrap_control_channel_metrics::( @@ -276,6 +300,7 @@ impl ReceiverWrapper { pdata_senders, pdata_receiver, telemetry, + extension_registry, } } } @@ -295,6 +320,7 @@ impl ReceiverWrapper { control_receiver, pdata_senders, user_config, + extension_registry, .. }, metrics_reporter, @@ -311,13 +337,16 @@ impl ReceiverWrapper { }; let default_port = user_config.default_out_port.clone(); let ctrl_msg_chan = local::ControlChannel::new(Receiver::Local(control_receiver)); - let effect_handler = local::EffectHandler::new( + let mut effect_handler = local::EffectHandler::new( node_id, msg_senders, default_port, pipeline_ctrl_msg_tx, metrics_reporter, ); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } receiver.start(ctrl_msg_chan, effect_handler).await } ( @@ -327,6 +356,7 @@ impl ReceiverWrapper { control_receiver, pdata_senders, user_config, + extension_registry, .. }, metrics_reporter, @@ -343,13 +373,16 @@ impl ReceiverWrapper { }; let default_port = user_config.default_out_port.clone(); let ctrl_msg_chan = shared::ControlChannel::new(control_receiver); - let effect_handler = shared::EffectHandler::new( + let mut effect_handler = shared::EffectHandler::new( node_id, msg_senders, default_port, pipeline_ctrl_msg_tx, metrics_reporter, ); + if let Some(registry) = extension_registry { + effect_handler.set_extension_registry(registry); + } receiver.start(ctrl_msg_chan, effect_handler).await } } diff --git a/rust/otap-dataflow/crates/engine/src/runtime_pipeline.rs b/rust/otap-dataflow/crates/engine/src/runtime_pipeline.rs index 2f1bb2ce3d..9f74532346 100644 --- a/rust/otap-dataflow/crates/engine/src/runtime_pipeline.rs +++ b/rust/otap-dataflow/crates/engine/src/runtime_pipeline.rs @@ -13,7 +13,10 @@ use crate::error::{Error, TypedError}; use crate::node::{Node, NodeDefs, NodeId, NodeType, NodeWithPDataReceiver, NodeWithPDataSender}; use crate::pipeline_ctrl::PipelineCtrlMsgManager; use crate::terminal_state::TerminalState; -use crate::{exporter::ExporterWrapper, processor::ProcessorWrapper, receiver::ReceiverWrapper}; +use crate::{ + exporter::ExporterWrapper, extension::ExtensionWrapper, processor::ProcessorWrapper, + receiver::ReceiverWrapper, +}; use otap_df_config::DeployedPipelineKey; use otap_df_config::pipeline::PipelineConfig; use otap_df_telemetry::event::ObservedEventReporter; @@ -35,6 +38,8 @@ pub struct RuntimePipeline { processors: Vec>, /// A map node id to exporter runtime node. exporters: Vec>, + /// A map node id to extension runtime node. + extensions: Vec>, /// A precomputed map of all node IDs to their Node trait objects (? @@@) for efficient access /// Indexed by NodeIndex @@ -70,6 +75,7 @@ impl RuntimePipeline { receivers: Vec>, processors: Vec>, exporters: Vec>, + extensions: Vec>, nodes: NodeDefs, ) -> Self { Self { @@ -77,6 +83,7 @@ impl RuntimePipeline { receivers, processors, exporters, + extensions, nodes, channel_metrics: Default::default(), } @@ -89,7 +96,7 @@ impl RuntimePipeline { /// Returns the number of nodes in the pipeline. #[must_use] pub fn node_count(&self) -> usize { - self.receivers.len() + self.processors.len() + self.exporters.len() + self.receivers.len() + self.processors.len() + self.exporters.len() + self.extensions.len() } /// Returns a reference to the pipeline configuration. @@ -116,6 +123,7 @@ impl RuntimePipeline { receivers, processors, exporters, + extensions, nodes: _nodes, channel_metrics, } = self; @@ -130,6 +138,46 @@ impl RuntimePipeline { let mut futures = FuturesUnordered::new(); let mut control_senders = ControlSenders::default(); + // Spawn extensions before other components to ensure that they are ready to be used by + // processors, receivers, and exporters. + for extension in extensions { + let mut extension = extension; + let node_id = extension.node_id(); + control_senders.register( + node_id.clone(), + NodeType::Extension, + extension.control_sender(), + ); + let telemetry_guard = extension.take_telemetry_guard(); + let node_entity_key = telemetry_guard.as_ref().map(|t| t.entity_key()); + let telemetry_handle = telemetry_guard.as_ref().map(|t| t.handle()); + let pipeline_ctrl_msg_tx = pipeline_ctrl_msg_tx.clone(); + let effect_metrics_reporter = metrics_reporter.clone(); + let final_metrics_reporter = metrics_reporter.clone(); + let fut = async move { + let result = extension + .start(pipeline_ctrl_msg_tx, effect_metrics_reporter) + .await + .map(|terminal_state| { + report_terminal_metrics(&final_metrics_reporter, terminal_state); + }); + drop(telemetry_guard); + result + }; + if let Some(handle) = telemetry_handle { + let input_key = handle.input_channel_key(); + let output_keys = handle.output_channel_keys(); + let node_ctx = + NodeTaskContext::new(node_entity_key, Some(handle), input_key, output_keys); + futures.push(local_tasks.spawn_local(instrument_with_node_context(node_ctx, fut))); + } else if let Some(key) = node_entity_key { + let node_ctx = NodeTaskContext::new(Some(key), None, None, Vec::new()); + futures.push(local_tasks.spawn_local(instrument_with_node_context(node_ctx, fut))); + } else { + futures.push(local_tasks.spawn_local(fut)); + } + } + // Spawn node tasks and register their control senders, scoping telemetry where available. for exporter in exporters { let mut exporter = exporter; @@ -294,6 +342,10 @@ impl RuntimePipeline { let ndef = self.nodes.get(node_id)?; match ndef.ntype { + NodeType::Extension => self + .extensions + .get(ndef.inner.index) + .map(|e| e as &dyn Node), NodeType::Receiver => self .receivers .get(ndef.inner.index) @@ -327,6 +379,7 @@ impl RuntimePipeline { .get_mut(ndef.inner.index) .map(|p| p as &mut dyn NodeWithPDataSender), NodeType::Exporter => None, + NodeType::Extension => None, // Extensions don't send pdata } } @@ -348,6 +401,7 @@ impl RuntimePipeline { .exporters .get_mut(ndef.inner.index) .map(|e| e as &mut dyn NodeWithPDataReceiver), + NodeType::Extension => None, // Extensions don't receive pdata } } @@ -380,6 +434,13 @@ impl RuntimePipeline { .send_control_msg(ctrl_msg) .await } + NodeType::Extension => { + self.extensions + .get(ndef.inner.index) + .expect("precomputed") + .send_control_msg(ctrl_msg) + .await + } } .map_err(|e| TypedError::NodeControlMsgSendError { node_id: node_id.index, diff --git a/rust/otap-dataflow/crates/engine/src/shared/exporter.rs b/rust/otap-dataflow/crates/engine/src/shared/exporter.rs index 0c8abdf05b..2da9a2c8cd 100644 --- a/rust/otap-dataflow/crates/engine/src/shared/exporter.rs +++ b/rust/otap-dataflow/crates/engine/src/shared/exporter.rs @@ -35,6 +35,8 @@ use crate::control::{AckMsg, NackMsg, NodeControlMsg}; use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; use crate::error::Error; +use crate::extensions::ExtensionTrait; +use crate::extensions::registry::ExtensionError; use crate::message::Message; use crate::node::NodeId; use crate::shared::message::SharedReceiver; @@ -231,6 +233,39 @@ impl EffectHandler { self.core.node_id() } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Gets an extension trait implementation by extension name. + /// + /// This allows exporters to look up capabilities provided by extensions, + /// such as authentication tokens or credentials. + /// + /// # Type Parameters + /// + /// * `T` - The trait type (e.g., `dyn BearerTokenProvider`). Must implement `ExtensionTrait`. + /// + /// # Errors + /// + /// Returns `ExtensionError::NotFound` if no extension with that name exists. + /// Returns `ExtensionError::TraitNotImplemented` if the extension doesn't implement the trait. + /// + /// # Example + /// + /// ```ignore + /// let token_provider: &dyn BearerTokenProvider = effect_handler + /// .get_extension::("azure_auth")?; + /// let token = token_provider.get_token(); + /// ``` + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + /// Print an info message to stdout. /// /// This method provides a standardized way for exporters to output diff --git a/rust/otap-dataflow/crates/engine/src/shared/extension.rs b/rust/otap-dataflow/crates/engine/src/shared/extension.rs new file mode 100644 index 0000000000..d23ca67c56 --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/shared/extension.rs @@ -0,0 +1,259 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Trait and structures used to implement shared extensions (Send bound). +//! +//! An extension is a special component that doesn't process pipeline data (pdata). +//! Extensions provide auxiliary services to the pipeline, such as health checks, +//! service discovery, or configuration management. +//! +//! Unlike receivers, processors, and exporters, extensions do not participate in +//! the data flow - they only handle control messages and provide services. +//! +//! # Lifecycle +//! +//! 1. The extension is instantiated and configured +//! 2. The `start` method is called, which begins the extension's operation +//! 3. The extension processes internal control messages +//! 4. The extension shuts down when it receives a `Shutdown` control message or encounters a fatal +//! error +//! +//! # Thread Safety +//! +//! This implementation is designed for use in both single-threaded and multi-threaded environments. +//! The `Extension` trait requires the `Send` bound, enabling the use of thread-safe types. + +use crate::control::{AckMsg, NackMsg, NodeControlMsg}; +use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; +use crate::error::Error; +use crate::extensions::{ExtensionError, ExtensionTrait}; +use crate::message::Message; +use crate::node::NodeId; +use crate::shared::message::SharedReceiver; +use crate::terminal_state::TerminalState; +use async_trait::async_trait; +use otap_df_channel::error::RecvError; +use otap_df_telemetry::error::Error as TelemetryError; +use otap_df_telemetry::metrics::{MetricSet, MetricSetHandler}; +use otap_df_telemetry::reporter::MetricsReporter; +use std::marker::PhantomData; +use std::pin::Pin; +use std::time::{Duration, Instant}; +use tokio::time::{Sleep, sleep_until}; + +/// A trait for extensions (Send definition). +/// +/// Extensions are special components that don't process pipeline data. +/// They provide auxiliary services to the pipeline. +#[async_trait] +pub trait Extension { + /// Similar to local::extension::Extension::start, but operates in a Send context. + async fn start( + self: Box, + msg_chan: MessageChannel, + effect_handler: EffectHandler, + ) -> Result; +} + +/// A channel for receiving control messages only (no pdata). +/// +/// Extensions only process control messages, not pipeline data. +pub struct MessageChannel { + control_rx: Option>>, + /// Once a Shutdown is seen, this is set to `Some(instant)` at which point + /// the extension should finish up. + shutting_down_deadline: Option, + /// Holds the ControlMsg::Shutdown until we're ready to return it. + pending_shutdown: Option>, +} + +impl MessageChannel { + /// Creates a new `MessageChannel` with the given control receiver. + #[must_use] + pub fn new(control_rx: SharedReceiver>) -> Self { + MessageChannel { + control_rx: Some(control_rx), + shutting_down_deadline: None, + pending_shutdown: None, + } + } + + /// Asynchronously receives the next control message to process. + /// + /// # Errors + /// + /// Returns a [`RecvError`] if the channel is closed, or if the + /// shutdown deadline has passed. + pub async fn recv(&mut self) -> Result, RecvError> { + let mut sleep_until_deadline: Option>> = None; + + loop { + if self.control_rx.is_none() { + // MessageChannel has been shutdown + return Err(RecvError::Closed); + } + + // Draining mode: Shutdown pending + if let Some(dl) = self.shutting_down_deadline { + // If the deadline has passed, emit the pending Shutdown now. + if Instant::now() >= dl { + let shutdown = self + .pending_shutdown + .take() + .expect("pending_shutdown must exist"); + self.shutdown(); + return Ok(Message::Control(shutdown)); + } + + if sleep_until_deadline.is_none() { + // Create a sleep timer for the deadline + sleep_until_deadline = Some(Box::pin(sleep_until(dl.into()))); + } + + // Wait for deadline or control messages + tokio::select! { + biased; + + // Timer expired + _ = async { sleep_until_deadline.as_mut().expect("sleep_until_deadline").as_mut().await }, if sleep_until_deadline.is_some() => { + let shutdown = self.pending_shutdown + .take() + .expect("pending_shutdown must exist"); + self.shutdown(); + return Ok(Message::Control(shutdown)); + } + + // Control messages (discard after shutdown) + ctrl = self.control_rx.as_mut().expect("control_rx must exist").recv() => match ctrl { + Ok(_) => { + // Discard control messages after shutdown is pending + continue; + } + Err(_) => { + let shutdown = self.pending_shutdown + .take() + .expect("pending_shutdown must exist"); + self.shutdown(); + return Ok(Message::Control(shutdown)); + } + } + } + } + + // Normal mode: wait for control messages + let ctrl_rx = self + .control_rx + .as_mut() + .expect("control_rx must exist in normal mode"); + + match ctrl_rx.recv().await { + Ok(ctrl) => { + if let NodeControlMsg::Shutdown { deadline, .. } = &ctrl { + self.shutting_down_deadline = Some(*deadline); + self.pending_shutdown = Some(ctrl); + // Continue to handle the shutdown in draining mode + continue; + } + return Ok(Message::Control(ctrl)); + } + Err(e) => { + self.shutdown(); + return Err(e); + } + } + } + } + + /// Shuts down the message channel by dropping the control receiver. + fn shutdown(&mut self) { + self.control_rx = None; + } +} + +/// A `Send` implementation of the EffectHandler for extensions. +#[derive(Clone)] +pub struct EffectHandler { + pub(crate) core: EffectHandlerCore, + _pd: PhantomData, +} + +impl EffectHandler { + /// Creates a new shared (Send) `EffectHandler` with the given extension node id and metrics + /// reporter. + #[must_use] + pub fn new(node_id: NodeId, metrics_reporter: MetricsReporter) -> Self { + EffectHandler { + core: EffectHandlerCore::new(node_id, metrics_reporter), + _pd: PhantomData, + } + } + + /// Returns the id of the extension associated with this handler. + #[must_use] + pub fn extension_id(&self) -> NodeId { + self.core.node_id() + } + + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Returns an extension trait implementation by name. + /// + /// # Errors + /// + /// Returns an [`ExtensionError`] if the extension is not found or doesn't implement the trait. + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + + /// Print an info message to stdout. + pub async fn info(&self, message: &str) { + self.core.info(message).await; + } + + /// Starts a cancellable periodic timer that emits TimerTick on the control channel. + pub async fn start_periodic_timer( + &self, + duration: Duration, + ) -> Result, Error> { + self.core.start_periodic_timer(duration).await + } + + /// Starts a cancellable periodic telemetry timer that emits CollectTelemetry. + pub async fn start_periodic_telemetry( + &self, + duration: Duration, + ) -> Result, Error> { + self.core.start_periodic_telemetry(duration).await + } + + /// Send an Ack to a node of known-interest. + pub async fn route_ack(&self, ack: AckMsg, cxf: F) -> Result<(), Error> + where + F: FnOnce(AckMsg) -> Option<(usize, AckMsg)>, + { + self.core.route_ack(ack, cxf).await + } + + /// Send a Nack to a node of known-interest. + pub async fn route_nack(&self, nack: NackMsg, cxf: F) -> Result<(), Error> + where + F: FnOnce(NackMsg) -> Option<(usize, NackMsg)>, + { + self.core.route_nack(nack, cxf).await + } + + /// Reports metrics collected by the extension. + #[allow(dead_code)] // Will be used in the future. ToDo report metrics from channel and messages. + pub(crate) fn report_metrics( + &mut self, + metrics: &mut MetricSet, + ) -> Result<(), TelemetryError> { + self.core.report_metrics(metrics) + } +} diff --git a/rust/otap-dataflow/crates/engine/src/shared/mod.rs b/rust/otap-dataflow/crates/engine/src/shared/mod.rs index 3f6f246f76..e71ea6822d 100644 --- a/rust/otap-dataflow/crates/engine/src/shared/mod.rs +++ b/rust/otap-dataflow/crates/engine/src/shared/mod.rs @@ -1,9 +1,10 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -//! Traits and structs defining the shared (Send) version of receivers, processors, and exporters. +//! Traits and structs defining the shared (Send) version of receivers, processors, exporters, and extensions. pub mod exporter; +pub mod extension; pub mod message; pub mod processor; pub mod receiver; diff --git a/rust/otap-dataflow/crates/engine/src/shared/processor.rs b/rust/otap-dataflow/crates/engine/src/shared/processor.rs index 30237aece5..e5ee044daa 100644 --- a/rust/otap-dataflow/crates/engine/src/shared/processor.rs +++ b/rust/otap-dataflow/crates/engine/src/shared/processor.rs @@ -34,6 +34,7 @@ use crate::control::{AckMsg, NackMsg}; use crate::effect_handler::{EffectHandlerCore, TelemetryTimerCancelHandle, TimerCancelHandle}; use crate::error::{Error, TypedError}; +use crate::extensions::{ExtensionError, ExtensionTrait}; use crate::message::Message; use crate::node::NodeId; use crate::shared::message::SharedSender; @@ -132,6 +133,23 @@ impl EffectHandler { self.core.node_id() } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Returns an extension trait implementation by name. + /// + /// # Errors + /// + /// Returns an [`ExtensionError`] if the extension is not found or doesn't implement the trait. + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + /// Returns the list of connected out ports for this processor. #[must_use] pub fn connected_ports(&self) -> Vec { diff --git a/rust/otap-dataflow/crates/engine/src/shared/receiver.rs b/rust/otap-dataflow/crates/engine/src/shared/receiver.rs index 57eceaccc2..45bdd68449 100644 --- a/rust/otap-dataflow/crates/engine/src/shared/receiver.rs +++ b/rust/otap-dataflow/crates/engine/src/shared/receiver.rs @@ -49,6 +49,8 @@ use std::net::SocketAddr; use std::time::Duration; use tokio::net::TcpListener; +use crate::extensions::{ExtensionError, ExtensionTrait}; + /// A trait for ingress receivers (Send definition). /// /// Receivers are responsible for accepting data from external sources and converting @@ -139,6 +141,23 @@ impl EffectHandler { self.core.node_id() } + /// Sets the extension registry for this effect handler. + pub fn set_extension_registry(&mut self, registry: crate::extensions::ExtensionRegistry) { + self.core.set_extension_registry(registry); + } + + /// Returns an extension trait implementation by name. + /// + /// # Errors + /// + /// Returns an [`ExtensionError`] if the extension is not found or doesn't implement the trait. + pub fn get_extension( + &self, + name: &str, + ) -> Result<&T, ExtensionError> { + self.core.get_extension::(name) + } + /// Returns the list of connected out ports for this receiver. #[must_use] pub fn connected_ports(&self) -> Vec { diff --git a/rust/otap-dataflow/crates/engine/src/testing/extension.rs b/rust/otap-dataflow/crates/engine/src/testing/extension.rs new file mode 100644 index 0000000000..0f8fba5429 --- /dev/null +++ b/rust/otap-dataflow/crates/engine/src/testing/extension.rs @@ -0,0 +1,227 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Test utilities for extensions. +//! +//! These utilities are designed to make testing extensions simpler by abstracting away common +//! setup and lifecycle management. + +use crate::config::ExtensionConfig; +use crate::control::{NodeControlMsg, PipelineCtrlMsgReceiver}; +use crate::extension::ExtensionWrapper; +use crate::local::message::{LocalReceiver, LocalSender}; +use crate::message::Sender; +use crate::shared::message::{SharedReceiver, SharedSender}; +use crate::testing::{CtrlMsgCounters, test_node}; +use otap_df_channel::error::SendError; +use otap_df_config::node::{NodeKind, NodeUserConfig}; +use serde_json::Value; +use std::fmt::Debug; +use std::marker::PhantomData; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +/// A context object that holds transmitters for use in test tasks. +pub struct TestContext { + /// Sender for control messages + control_tx: Sender>, + /// Message counter for tracking processed messages + counters: CtrlMsgCounters, + /// Receiver for pipeline control messages + pipeline_ctrl_msg_receiver: Option>, +} + +impl Clone for TestContext { + fn clone(&self) -> Self { + Self { + control_tx: self.control_tx.clone(), + counters: self.counters.clone(), + pipeline_ctrl_msg_receiver: None, + } + } +} + +impl TestContext { + /// Creates a new TestContext with the given transmitters. + #[must_use] + pub fn new(control_tx: Sender>, counters: CtrlMsgCounters) -> Self { + Self { + control_tx, + counters, + pipeline_ctrl_msg_receiver: None, + } + } + + /// Returns the control message counters. + #[must_use] + pub fn counters(&self) -> CtrlMsgCounters { + self.counters.clone() + } + + /// Takes the pipeline control message receiver from the context. + /// Returns None if already taken. + pub fn take_pipeline_ctrl_receiver(&mut self) -> Option> { + self.pipeline_ctrl_msg_receiver.take() + } + + /// Sends a timer tick control message. + /// + /// # Errors + /// + /// Returns an error if the message could not be sent. + pub async fn send_timer_tick(&self) -> Result<(), SendError>> { + self.control_tx.send(NodeControlMsg::TimerTick {}).await + } + + /// Sends a config control message. + /// + /// # Errors + /// + /// Returns an error if the message could not be sent. + pub async fn send_config(&self, config: Value) -> Result<(), SendError>> { + self.control_tx + .send(NodeControlMsg::Config { config }) + .await + } + + /// Sends a shutdown control message with a specified deadline. + /// + /// # Errors + /// + /// Returns an error if the message could not be sent. + pub async fn send_shutdown( + &self, + deadline: Instant, + reason: &str, + ) -> Result<(), SendError>> { + self.control_tx + .send(NodeControlMsg::Shutdown { + deadline, + reason: reason.to_owned(), + }) + .await + } + + /// Sends a shutdown control message with a deadline relative to now. + /// + /// # Errors + /// + /// Returns an error if the message could not be sent. + pub async fn send_shutdown_in( + &self, + duration: Duration, + ) -> Result<(), SendError>> { + self.send_shutdown(Instant::now() + duration, "test shutdown") + .await + } +} + +/// A runtime for testing extensions. +/// +/// This struct provides methods for setting up and running extension tests +/// in a controlled environment. +pub struct TestRuntime { + _marker: PhantomData, +} + +impl TestRuntime { + /// Creates a new test runtime. + #[must_use] + pub fn new() -> Self { + Self { + _marker: PhantomData, + } + } + + /// Sets up a local extension for testing. + /// + /// Returns the extension wrapper and a test context for controlling it. + #[must_use] + pub fn setup_local_extension( + &self, + name: &str, + create_extension: F, + ) -> (ExtensionWrapper, TestContext) + where + F: FnOnce(CtrlMsgCounters) -> Box>, + { + let counters = CtrlMsgCounters::new(); + let extension = create_extension(counters.clone()); + let name_owned = name.to_owned(); + let node_id = test_node(name_owned.clone()); + let user_config = Arc::new(NodeUserConfig::with_user_config( + NodeKind::Receiver, // Extension is not a config kind yet, use Receiver as placeholder + format!("urn:test:{name}").into(), + Value::Null, + )); + let config = ExtensionConfig::new(name_owned); + + let (control_tx, control_rx) = + otap_df_channel::mpsc::Channel::>::new(32); + + let wrapper = ExtensionWrapper::Local { + node_id, + user_config, + runtime_config: config, + extension, + extension_traits: Some(crate::extensions::registry::ExtensionTraits::new()), + control_sender: LocalSender::mpsc(control_tx.clone()), + control_receiver: LocalReceiver::mpsc(control_rx), + telemetry: None, + extension_registry: None, + }; + + let test_context = TestContext::new(Sender::Local(LocalSender::mpsc(control_tx)), counters); + + (wrapper, test_context) + } + + /// Sets up a shared extension for testing. + /// + /// Returns the extension wrapper and a test context for controlling it. + #[must_use] + pub fn setup_shared_extension( + &self, + name: &str, + create_extension: F, + ) -> (ExtensionWrapper, TestContext) + where + F: FnOnce(CtrlMsgCounters) -> Box + Send>, + { + let counters = CtrlMsgCounters::new(); + let extension = create_extension(counters.clone()); + let name_owned = name.to_owned(); + let node_id = test_node(name_owned.clone()); + let user_config = Arc::new(NodeUserConfig::with_user_config( + NodeKind::Receiver, // Extension is not a config kind yet, use Receiver as placeholder + format!("urn:test:{name}").into(), + Value::Null, + )); + let config = ExtensionConfig::new(name_owned); + + let (control_tx, control_rx) = tokio::sync::mpsc::channel::>(32); + + let wrapper = ExtensionWrapper::Shared { + node_id, + user_config, + runtime_config: config, + extension, + extension_traits: Some(crate::extensions::registry::ExtensionTraits::new()), + control_sender: SharedSender::mpsc(control_tx.clone()), + control_receiver: SharedReceiver::mpsc(control_rx), + telemetry: None, + extension_registry: None, + }; + + let test_context = + TestContext::new(Sender::Shared(SharedSender::mpsc(control_tx)), counters); + + (wrapper, test_context) + } +} + +impl Default for TestRuntime { + fn default() -> Self { + Self::new() + } +} diff --git a/rust/otap-dataflow/crates/engine/src/testing/mod.rs b/rust/otap-dataflow/crates/engine/src/testing/mod.rs index 22bec1333c..d7ae5186c5 100644 --- a/rust/otap-dataflow/crates/engine/src/testing/mod.rs +++ b/rust/otap-dataflow/crates/engine/src/testing/mod.rs @@ -22,6 +22,7 @@ use tokio::runtime::Builder; use tokio::task::LocalSet; pub mod exporter; +pub mod extension; pub mod node; pub mod processor; pub mod receiver; diff --git a/rust/otap-dataflow/crates/otap/Cargo.toml b/rust/otap-dataflow/crates/otap/Cargo.toml index 0a855da9fe..92118cac86 100644 --- a/rust/otap-dataflow/crates/otap/Cargo.toml +++ b/rust/otap-dataflow/crates/otap/Cargo.toml @@ -109,6 +109,8 @@ byte-unit.workspace = true experimental-tls = ["dep:rustls", "dep:rustls-pki-types", "dep:tokio-rustls", "dep:rustls-native-certs", "tonic/tls-ring", "dep:arc-swap", "dep:notify"] # Base experimental feature - enabled by all experimental exporters experimental-exporters = [] +# Base experimental feature - enabled by all experimental extensions +experimental-extensions = [] # Experimental exporters geneva-exporter = ["experimental-exporters", "dep:geneva-uploader", "dep:opentelemetry-proto"] azure-monitor-exporter = [ @@ -120,6 +122,12 @@ azure-monitor-exporter = [ "dep:ahash", "dep:sysinfo", ] +# Experimental extensions +azure-identity-auth-extension = [ + "experimental-extensions", + "dep:azure_identity", + "dep:azure_core", +] azure = ["object_store/azure", "object_store/cloud", "dep:azure_identity", "dep:azure_core"] # Experimental processors experimental-processors = [] diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/config.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/config.rs new file mode 100644 index 0000000000..49c9246883 --- /dev/null +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/config.rs @@ -0,0 +1,156 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Configuration types for the Azure Identity Auth Extension. + +use serde::Deserialize; + +use super::Error; + +/// Authentication method for Azure. +#[derive(Debug, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "lowercase")] +pub enum AuthMethod { + /// Use Managed Identity (system or user-assigned with client_id). + #[serde(alias = "msi", alias = "managed_identity")] + #[default] + ManagedIdentity, + + /// Use developer tools (Azure CLI, Azure Developer CLI). + #[serde(alias = "dev", alias = "developer", alias = "cli")] + Development, +} + +impl std::fmt::Display for AuthMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthMethod::ManagedIdentity => write!(f, "managed_identity"), + AuthMethod::Development => write!(f, "development"), + } + } +} + +/// Configuration for the Azure Identity Auth Extension. +#[derive(Debug, Deserialize, Clone)] +#[serde(deny_unknown_fields)] +pub struct Config { + /// Authentication method to use. + #[serde(default)] + pub method: AuthMethod, + + /// Client ID for user-assigned managed identity (optional). + /// Only used when method is ManagedIdentity. + /// If not provided with ManagedIdentity, system-assigned identity will be used. + pub client_id: Option, + + /// OAuth scope for token acquisition. + /// Defaults to "https://management.azure.com/.default" for general Azure management. + #[serde(default = "default_scope")] + pub scope: String, +} + +impl Default for Config { + fn default() -> Self { + Self { + method: AuthMethod::default(), + client_id: None, + scope: default_scope(), + } + } +} + +impl Config { + /// Validate the configuration. + pub fn validate(&self) -> Result<(), Error> { + // Validate scope is not empty + if self.scope.is_empty() { + return Err(Error::Config("OAuth scope cannot be empty".to_string())); + } + + Ok(()) + } +} + +fn default_scope() -> String { + "https://management.azure.com/.default".to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = Config::default(); + assert_eq!(config.method, AuthMethod::ManagedIdentity); + assert!(config.client_id.is_none()); + assert_eq!(config.scope, "https://management.azure.com/.default"); + } + + #[test] + fn test_auth_method_display() { + assert_eq!( + format!("{}", AuthMethod::ManagedIdentity), + "managed_identity" + ); + assert_eq!(format!("{}", AuthMethod::Development), "development"); + } + + #[test] + fn test_config_validation_empty_scope() { + let config = Config { + method: AuthMethod::ManagedIdentity, + client_id: None, + scope: "".to_string(), + }; + let result = config.validate(); + assert!(result.is_err()); + } + + #[test] + fn test_config_validation_valid() { + let config = Config::default(); + let result = config.validate(); + assert!(result.is_ok()); + } + + #[test] + fn test_deserialize_managed_identity_system_assigned() { + let json = r#"{ + "method": "managed_identity", + "scope": "https://monitor.azure.com/.default" + }"#; + let config: Config = serde_json::from_str(json).unwrap(); + assert_eq!(config.method, AuthMethod::ManagedIdentity); + assert!(config.client_id.is_none()); + } + + #[test] + fn test_deserialize_managed_identity_user_assigned() { + let json = r#"{ + "method": "msi", + "client_id": "12345-abcde", + "scope": "https://monitor.azure.com/.default" + }"#; + let config: Config = serde_json::from_str(json).unwrap(); + assert_eq!(config.method, AuthMethod::ManagedIdentity); + assert_eq!(config.client_id, Some("12345-abcde".to_string())); + } + + #[test] + fn test_deserialize_development() { + let json = r#"{ + "method": "development" + }"#; + let config: Config = serde_json::from_str(json).unwrap(); + assert_eq!(config.method, AuthMethod::Development); + } + + #[test] + fn test_deserialize_with_defaults() { + let json = r#"{}"#; + let config: Config = serde_json::from_str(json).unwrap(); + assert_eq!(config.method, AuthMethod::ManagedIdentity); + assert_eq!(config.scope, "https://management.azure.com/.default"); + } +} diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/error.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/error.rs new file mode 100644 index 0000000000..478f053a62 --- /dev/null +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/error.rs @@ -0,0 +1,132 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Error types for the Azure Identity Auth Extension. + +use super::config::AuthMethod; + +/// Error definitions for Azure Identity Auth Extension. +#[derive(thiserror::Error, Debug)] +pub enum Error { + // ==================== Configuration Errors ==================== + /// Error during configuration of a component. + #[error("Configuration error: {0}")] + Config(String), + + // ==================== Authentication Errors ==================== + /// Authentication/authorization error. + #[error("Auth error ({kind})")] + Auth { + /// The kind of authentication error. + kind: AuthErrorKind, + /// The underlying Azure error, if any. + #[source] + source: Option, + }, + + // ==================== Internal Errors ==================== + /// Shutdown requested. + #[error("Shutdown requested: {reason}")] + Shutdown { + /// The reason for shutdown. + reason: String, + }, +} + +/// Specific authentication error variants. +#[derive(Debug, Clone, PartialEq)] +pub enum AuthErrorKind { + /// Failed to create the credential provider. + CreateCredential { + /// The authentication method that failed. + method: AuthMethod, + }, + + /// Failed to acquire a token. + TokenAcquisition, + + /// Token has expired and refresh failed. + TokenExpired, +} + +impl std::fmt::Display for AuthErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AuthErrorKind::CreateCredential { method } => { + write!(f, "failed to create credential for method: {}", method) + } + AuthErrorKind::TokenAcquisition => write!(f, "failed to acquire token"), + AuthErrorKind::TokenExpired => write!(f, "token expired and refresh failed"), + } + } +} + +impl Error { + /// Creates a new credential creation error. + #[must_use] + pub fn create_credential(method: AuthMethod, source: azure_core::error::Error) -> Self { + Error::Auth { + kind: AuthErrorKind::CreateCredential { method }, + source: Some(source), + } + } + + /// Creates a new token acquisition error. + #[must_use] + pub fn token_acquisition(source: azure_core::error::Error) -> Self { + Error::Auth { + kind: AuthErrorKind::TokenAcquisition, + source: Some(source), + } + } + + /// Creates a new token expired error. + #[must_use] + pub fn token_expired() -> Self { + Error::Auth { + kind: AuthErrorKind::TokenExpired, + source: None, + } + } + + /// Creates a new shutdown error. + pub fn shutdown(reason: impl Into) -> Self { + Error::Shutdown { + reason: reason.into(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_config_error_display() { + let err = Error::Config("test error".to_string()); + assert_eq!(format!("{}", err), "Configuration error: test error"); + } + + #[test] + fn test_auth_error_kind_display() { + let kind = AuthErrorKind::CreateCredential { + method: AuthMethod::ManagedIdentity, + }; + assert!(format!("{}", kind).contains("managed_identity")); + + let kind = AuthErrorKind::TokenAcquisition; + assert_eq!(format!("{}", kind), "failed to acquire token"); + + let kind = AuthErrorKind::TokenExpired; + assert_eq!(format!("{}", kind), "token expired and refresh failed"); + } + + #[test] + fn test_shutdown_error() { + let err = Error::shutdown("test reason"); + match err { + Error::Shutdown { reason } => assert_eq!(reason, "test reason"), + _ => panic!("Expected Shutdown error"), + } + } +} diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/extension.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/extension.rs new file mode 100644 index 0000000000..eb1154ebac --- /dev/null +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/extension.rs @@ -0,0 +1,752 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Azure Identity Auth Extension implementation. + +use async_trait::async_trait; +use azure_core::credentials::{AccessToken, TokenCredential}; +use azure_identity::{ + DeveloperToolsCredential, DeveloperToolsCredentialOptions, ManagedIdentityCredential, + ManagedIdentityCredentialOptions, UserAssignedId, +}; +use otap_df_engine::extensions::{BearerToken, BearerTokenProvider}; +use std::sync::Arc; +use tokio::sync::watch; + +use otap_df_engine::control::NodeControlMsg; +use otap_df_engine::error::Error as EngineError; +use otap_df_engine::local::extension::{EffectHandler, Extension}; +use otap_df_engine::message::{Message, MessageChannel}; +use otap_df_engine::terminal_state::TerminalState; + +use crate::pdata::OtapPdata; + +use super::config::{AuthMethod, Config}; +use super::error::Error; + +/// Minimum delay between token refresh retry attempts in seconds. +const MIN_RETRY_DELAY_SECS: f64 = 5.0; +/// Maximum delay between token refresh retry attempts in seconds. +const MAX_RETRY_DELAY_SECS: f64 = 30.0; +/// Maximum jitter percentage (±10%) to add to retry delays. +const MAX_RETRY_JITTER_RATIO: f64 = 0.10; + +/// Buffer time before token expiry to trigger refresh (in seconds). +/// Tokens will be refreshed ~5 minutes before they expire. +const TOKEN_EXPIRY_BUFFER_SECS: u64 = 299; +/// Minimum interval between token refresh attempts (in seconds). +const MIN_TOKEN_REFRESH_INTERVAL_SECS: u64 = 10; +/// Retry interval when token refresh fails (in seconds). +const TOKEN_REFRESH_RETRY_SECS: u64 = 10; + +/// Azure Identity Auth Extension. +/// +/// This extension provides Azure authentication services to the pipeline. +/// It manages Azure credentials and provides token acquisition capabilities. +pub struct AzureIdentityAuthExtension { + /// The Azure credential provider. + credential: Arc, + /// The OAuth scope for token acquisition. + scope: String, + /// The authentication method being used. + method: AuthMethod, + /// Sender for broadcasting token refresh events to subscribers. + token_sender: watch::Sender>, +} + +// TODO: Remove print_stdout after logging is set up +#[allow(clippy::print_stdout)] +impl AzureIdentityAuthExtension { + /// Creates a new Azure Identity Auth Extension with the given configuration. + pub fn new(config: Config) -> Result { + let credential = Self::create_credential(&config)?; + let (token_sender, _) = watch::channel(None); + + Ok(Self { + credential, + scope: config.scope.clone(), + method: config.method.clone(), + token_sender, + }) + } + + /// Creates a credential provider based on the configuration. + fn create_credential(config: &Config) -> Result, Error> { + match config.method { + AuthMethod::ManagedIdentity => { + let mut options = ManagedIdentityCredentialOptions::default(); + + if let Some(client_id) = &config.client_id { + println!( + "[AzureIdentityAuthExtension] Using user-assigned managed identity with client_id: {}", + client_id + ); + options.user_assigned_id = Some(UserAssignedId::ClientId(client_id.clone())); + } else { + println!("[AzureIdentityAuthExtension] Using system-assigned managed identity"); + } + + Ok(ManagedIdentityCredential::new(Some(options)) + .map_err(|e| Error::create_credential(AuthMethod::ManagedIdentity, e))?) + } + AuthMethod::Development => { + println!( + "[AzureIdentityAuthExtension] Using developer tools credential (Azure CLI / Azure Developer CLI)" + ); + Ok( + DeveloperToolsCredential::new(Some(DeveloperToolsCredentialOptions::default())) + .map_err(|e| Error::create_credential(AuthMethod::Development, e))?, + ) + } + } + } + + /// Gets a token from the credential provider. + async fn get_token_internal(&self) -> Result { + let token_response = self + .credential + .get_token( + &[&self.scope], + Some(azure_core::credentials::TokenRequestOptions::default()), + ) + .await + .map_err(Error::token_acquisition)?; + + Ok(token_response) + } + + /// Gets a token with retry logic. + /// + /// This method implements exponential backoff with jitter for retrying + /// token acquisition on failure. + pub async fn get_token(&self) -> Result { + let mut attempt = 0_i32; + loop { + attempt += 1; + + match self.get_token_internal().await { + Ok(token) => { + println!( + "[AzureIdentityAuthExtension] Obtained access token, expires on {}", + token.expires_on + ); + return Ok(token); + } + Err(e) => { + println!( + "[AzureIdentityAuthExtension] Failed to obtain access token (attempt {}): {}", + attempt, e + ); + } + } + + // Calculate exponential backoff: 5s, 10s, 20s, 30s (capped) + let base_delay_secs = MIN_RETRY_DELAY_SECS * 2.0_f64.powi(attempt - 1); + let capped_delay_secs = base_delay_secs.min(MAX_RETRY_DELAY_SECS); + + // Add jitter: random value between -10% and +10% of the delay + let jitter_range = capped_delay_secs * MAX_RETRY_JITTER_RATIO; + let jitter = if jitter_range > 0.0 { + let random_factor = rand::random::() * 2.0 - 1.0; + random_factor * jitter_range + } else { + 0.0 + }; + + let delay_secs = (capped_delay_secs + jitter).max(1.0); + let delay = tokio::time::Duration::from_secs_f64(delay_secs); + + println!( + "[AzureIdentityAuthExtension] Retrying in {:.1}s...", + delay_secs + ); + tokio::time::sleep(delay).await; + } + } + + /// Calculates when the next token refresh should occur. + /// + /// This schedules refresh before the token expires (with a buffer), + /// but ensures we don't refresh too frequently. + fn get_next_token_refresh(token: &BearerToken) -> tokio::time::Instant { + let now_secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_secs() as i64) + .unwrap_or(0); + + let duration_remaining = if token.expires_on > now_secs { + std::time::Duration::from_secs((token.expires_on - now_secs) as u64) + } else { + std::time::Duration::ZERO + }; + + let token_valid_until = tokio::time::Instant::now() + duration_remaining; + let next_token_refresh = + token_valid_until - tokio::time::Duration::from_secs(TOKEN_EXPIRY_BUFFER_SECS); + std::cmp::max( + next_token_refresh, + tokio::time::Instant::now() + + tokio::time::Duration::from_secs(MIN_TOKEN_REFRESH_INTERVAL_SECS), + ) + } + + /// Returns the authentication method being used. + #[must_use] + pub fn method(&self) -> &AuthMethod { + &self.method + } + + /// Returns the OAuth scope. + #[must_use] + pub fn scope(&self) -> &str { + &self.scope + } + + /// Returns a clone of the credential for sharing with other components. + #[must_use] + pub fn credential(&self) -> Arc { + self.credential.clone() + } +} + +#[async_trait] +impl BearerTokenProvider for AzureIdentityAuthExtension { + async fn get_token(&self) -> Result { + let access_token = AzureIdentityAuthExtension::get_token(self).await?; + + Ok(BearerToken::new( + access_token.token.secret().to_string(), + access_token.expires_on.unix_timestamp(), + )) + } + + fn subscribe_token_refresh(&self) -> watch::Receiver> { + self.token_sender.subscribe() + } +} + +#[async_trait(?Send)] +impl Extension for AzureIdentityAuthExtension { + #[allow(clippy::print_stdout)] + async fn start( + self: Box, + mut msg_chan: MessageChannel, + effect_handler: EffectHandler, + ) -> Result { + effect_handler + .info(&format!( + "[AzureIdentityAuthExtension] Started with {} authentication", + self.method + )) + .await; + + // Fetch initial token immediately + let mut next_token_refresh = tokio::time::Instant::now(); + + // Main event loop - extensions handle control messages and proactive token refresh + loop { + tokio::select! { + biased; + + // Proactive token refresh - keeps Azure Identity's internal cache warm + _ = tokio::time::sleep_until(next_token_refresh) => { + match AzureIdentityAuthExtension::get_token(&self).await { + Ok(access_token) => { + let bearer_token = BearerToken::new( + access_token.token.secret().to_string(), + access_token.expires_on.unix_timestamp(), + ); + + // Broadcast the new token to all subscribers + let _ = self.token_sender.send(Some(bearer_token.clone())); + + // Schedule next refresh + next_token_refresh = Self::get_next_token_refresh(&bearer_token); + + let refresh_in = next_token_refresh.saturating_duration_since(tokio::time::Instant::now()); + let total_secs = refresh_in.as_secs(); + let hours = total_secs / 3600; + let minutes = (total_secs % 3600) / 60; + let seconds = total_secs % 60; + + effect_handler + .info(&format!( + "[AzureIdentityAuthExtension] Token refreshed, next refresh in {}h {}m {}s", + hours, minutes, seconds + )) + .await; + } + Err(e) => { + effect_handler + .info(&format!("[AzureIdentityAuthExtension] Failed to refresh token: {:?}, retrying in {}s", e, TOKEN_REFRESH_RETRY_SECS)) + .await; + // Retry after a short delay + next_token_refresh = tokio::time::Instant::now() + + tokio::time::Duration::from_secs(TOKEN_REFRESH_RETRY_SECS); + } + } + } + + // Handle control messages + msg = msg_chan.recv() => { + match msg? { + Message::Control(NodeControlMsg::Shutdown { reason, .. }) => { + effect_handler + .info(&format!( + "[AzureIdentityAuthExtension] Shutting down: {}", + reason + )) + .await; + break; + } + Message::Control(NodeControlMsg::TimerTick {}) => { + // Timer ticks handled by tokio::select sleep_until + } + Message::Control(NodeControlMsg::Config { config }) => { + // Handle dynamic configuration updates + effect_handler + .info(&format!( + "[AzureIdentityAuthExtension] Received config update: {:?}", + config + )) + .await; + } + Message::PData(_) => { + // Extensions don't process pipeline data - this shouldn't happen + effect_handler + .info("[AzureIdentityAuthExtension] Received unexpected PData message") + .await; + } + _ => { + // Handle other control messages as needed + } + } + } + } + } + + Ok(TerminalState::default()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use azure_core::credentials::TokenRequestOptions; + use azure_core::time::OffsetDateTime; + use std::sync::atomic::{AtomicUsize, Ordering}; + + #[derive(Debug)] + struct MockCredential { + token: String, + expires_in: azure_core::time::Duration, + call_count: Arc, + } + + fn make_mock_credential( + token: &str, + expires_in: azure_core::time::Duration, + call_count: Arc, + ) -> Arc { + let cred: Arc = Arc::new(MockCredential { + token: token.to_string(), + expires_in, + call_count, + }); + cred + } + + #[async_trait::async_trait] + impl TokenCredential for MockCredential { + async fn get_token( + &self, + _scopes: &[&str], + _options: Option>, + ) -> azure_core::Result { + let _ = self.call_count.fetch_add(1, Ordering::SeqCst); + + Ok(AccessToken { + token: self.token.clone().into(), + expires_on: OffsetDateTime::now_utc() + self.expires_in, + }) + } + } + + impl AzureIdentityAuthExtension { + /// Creates an extension with a mock credential for testing. + #[cfg(test)] + pub fn from_mock( + credential: Arc, + scope: String, + method: AuthMethod, + ) -> Self { + let (token_sender, _) = watch::channel(None); + Self { + credential, + scope, + method, + token_sender, + } + } + } + + // ==================== Construction Tests ==================== + + #[tokio::test] + async fn test_from_mock_creates_extension() { + let credential = make_mock_credential( + "test_token", + azure_core::time::Duration::minutes(60), + Arc::new(AtomicUsize::new(0)), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "test_scope".to_string(), + AuthMethod::Development, + ); + assert_eq!(ext.scope(), "test_scope"); + assert_eq!(ext.method(), &AuthMethod::Development); + } + + #[tokio::test] + async fn test_new_with_managed_identity_system_assigned() { + let config = Config { + method: AuthMethod::ManagedIdentity, + client_id: None, + scope: "https://test.scope".to_string(), + }; + + let ext = AzureIdentityAuthExtension::new(config); + assert!(ext.is_ok()); + let ext = ext.unwrap(); + assert_eq!(ext.scope(), "https://test.scope"); + } + + #[tokio::test] + async fn test_new_with_managed_identity_user_assigned() { + let config = Config { + method: AuthMethod::ManagedIdentity, + client_id: Some("test-client-id".to_string()), + scope: "https://test.scope".to_string(), + }; + + let ext = AzureIdentityAuthExtension::new(config); + assert!(ext.is_ok()); + } + + #[tokio::test] + async fn test_new_with_development_auth() { + let config = Config { + method: AuthMethod::Development, + client_id: None, + scope: "https://test.scope".to_string(), + }; + + // May fail if Azure CLI not installed - both outcomes are valid + let result = AzureIdentityAuthExtension::new(config); + match result { + Ok(ext) => assert_eq!(ext.scope(), "https://test.scope"), + Err(Error::Auth { + kind: super::super::error::AuthErrorKind::CreateCredential { method }, + .. + }) => { + assert_eq!(method, AuthMethod::Development); + } + Err(err) => panic!("Unexpected error type: {:?}", err), + } + } + + // ==================== Token Fetching Tests ==================== + + #[tokio::test] + async fn test_get_token_internal_returns_valid_token() { + let call_count = Arc::new(AtomicUsize::new(0)); + let credential = make_mock_credential( + "test_token", + azure_core::time::Duration::minutes(60), + call_count.clone(), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + let token = ext.get_token_internal().await.unwrap(); + assert_eq!(token.token.secret(), "test_token"); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_get_token_internal_calls_credential_each_time() { + let call_count = Arc::new(AtomicUsize::new(0)); + let credential = make_mock_credential( + "test_token", + azure_core::time::Duration::minutes(60), + call_count.clone(), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + // Each call to get_token_internal should call the credential + let _ = ext.get_token_internal().await.unwrap(); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + + let _ = ext.get_token_internal().await.unwrap(); + assert_eq!(call_count.load(Ordering::SeqCst), 2); + + let _ = ext.get_token_internal().await.unwrap(); + assert_eq!(call_count.load(Ordering::SeqCst), 3); + } + + // ==================== Credential Sharing Tests ==================== + + #[tokio::test] + async fn test_credential_returns_shared_reference() { + let call_count = Arc::new(AtomicUsize::new(0)); + let credential = make_mock_credential( + "test_token", + azure_core::time::Duration::minutes(60), + call_count.clone(), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + // Get shared credential + let shared_cred = ext.credential(); + + // Both should work + let token1 = ext.get_token_internal().await.unwrap(); + let token2 = shared_cred + .get_token(&["scope"], Some(TokenRequestOptions::default())) + .await + .unwrap(); + + assert_eq!(token1.token.secret(), "test_token"); + assert_eq!(token2.token.secret(), "test_token"); + assert_eq!(call_count.load(Ordering::SeqCst), 2); + } + + // ==================== Error Handling Tests ==================== + + #[tokio::test] + async fn test_get_token_internal_propagates_credential_error() { + #[derive(Debug)] + struct FailingCredential; + + #[async_trait::async_trait] + impl TokenCredential for FailingCredential { + async fn get_token( + &self, + _scopes: &[&str], + _options: Option>, + ) -> azure_core::Result { + Err(azure_core::error::Error::new( + azure_core::error::ErrorKind::Credential, + "Mock credential failure", + )) + } + } + + let credential: Arc = Arc::new(FailingCredential); + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + let result = ext.get_token_internal().await; + assert!(result.is_err()); + match result.unwrap_err() { + Error::Auth { + kind: super::super::error::AuthErrorKind::TokenAcquisition, + .. + } => {} + err => panic!("Expected Auth token acquisition error, got: {:?}", err), + } + } + + // ==================== BearerTokenProvider Trait Tests ==================== + + #[tokio::test] + async fn test_bearer_token_provider_get_token() { + let call_count = Arc::new(AtomicUsize::new(0)); + let credential = make_mock_credential( + "bearer_test_token", + azure_core::time::Duration::minutes(60), + call_count.clone(), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + // Use the BearerTokenProvider trait method + let token: BearerToken = BearerTokenProvider::get_token(&ext).await.unwrap(); + assert_eq!(token.token.secret(), "bearer_test_token"); + assert!(token.expires_on > 0); + assert_eq!(call_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn test_bearer_token_provider_subscribe_token_refresh() { + let credential = make_mock_credential( + "test_token", + azure_core::time::Duration::minutes(60), + Arc::new(AtomicUsize::new(0)), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + // Get a subscriber + let mut rx = ext.subscribe_token_refresh(); + + // Initially should be None + assert!(rx.borrow().is_none()); + + // Simulate token broadcast (using internal sender) + let new_token = BearerToken::new("refreshed_token".to_string(), 12345); + let _ = ext.token_sender.send(Some(new_token.clone())); + + // Subscriber should receive the update + rx.changed().await.unwrap(); + let received = rx.borrow(); + assert!(received.is_some()); + let received_token = received.as_ref().unwrap(); + assert_eq!(received_token.token.secret(), "refreshed_token"); + assert_eq!(received_token.expires_on, 12345); + } + + #[tokio::test] + async fn test_multiple_subscribers_receive_token_updates() { + let credential = make_mock_credential( + "test_token", + azure_core::time::Duration::minutes(60), + Arc::new(AtomicUsize::new(0)), + ); + + let ext = AzureIdentityAuthExtension::from_mock( + credential, + "scope".to_string(), + AuthMethod::Development, + ); + + // Create multiple subscribers + let mut rx1 = ext.subscribe_token_refresh(); + let mut rx2 = ext.subscribe_token_refresh(); + + // Broadcast a token + let token = BearerToken::new("broadcast_token".to_string(), 99999); + let _ = ext.token_sender.send(Some(token)); + + // Both subscribers should receive the update + rx1.changed().await.unwrap(); + rx2.changed().await.unwrap(); + + assert_eq!( + rx1.borrow().as_ref().unwrap().token.secret(), + "broadcast_token" + ); + assert_eq!( + rx2.borrow().as_ref().unwrap().token.secret(), + "broadcast_token" + ); + } + + // ==================== Token Refresh Scheduling Tests ==================== + + #[test] + fn test_get_next_token_refresh_schedules_before_expiry() { + // Token expires in 10 minutes (600 seconds) + let now_secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let token = BearerToken::new("test".to_string(), now_secs + 600); + let next_refresh = AzureIdentityAuthExtension::get_next_token_refresh(&token); + + // Should refresh before expiry (600s - 299s buffer = ~301s from now) + // But at least MIN_TOKEN_REFRESH_INTERVAL_SECS (10s) from now + let now = tokio::time::Instant::now(); + let min_expected = now + tokio::time::Duration::from_secs(MIN_TOKEN_REFRESH_INTERVAL_SECS); + + assert!(next_refresh >= min_expected); + } + + #[test] + fn test_get_next_token_refresh_respects_minimum_interval() { + // Token expires very soon (in 5 seconds) + let now_secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let token = BearerToken::new("test".to_string(), now_secs + 5); + let next_refresh = AzureIdentityAuthExtension::get_next_token_refresh(&token); + + // Should still wait at least MIN_TOKEN_REFRESH_INTERVAL_SECS (allowing 1s tolerance) + let now = tokio::time::Instant::now(); + let min_expected = + now + tokio::time::Duration::from_secs(MIN_TOKEN_REFRESH_INTERVAL_SECS - 1); + + assert!(next_refresh >= min_expected); + } + + #[test] + fn test_get_next_token_refresh_handles_expired_token() { + // Token already expired (in the past) + let now_secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let token = BearerToken::new("test".to_string(), now_secs - 100); + let next_refresh = AzureIdentityAuthExtension::get_next_token_refresh(&token); + + // Should schedule refresh at minimum interval (allowing 1s tolerance) + let now = tokio::time::Instant::now(); + let min_expected = + now + tokio::time::Duration::from_secs(MIN_TOKEN_REFRESH_INTERVAL_SECS - 1); + + assert!(next_refresh >= min_expected); + } + + #[test] + fn test_get_next_token_refresh_long_lived_token() { + // Token expires in 1 hour (3600 seconds) + let now_secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() as i64; + + let token = BearerToken::new("test".to_string(), now_secs + 3600); + let next_refresh = AzureIdentityAuthExtension::get_next_token_refresh(&token); + + // Should refresh ~5 minutes before expiry (3600 - 299 = 3301s from now) + let now = tokio::time::Instant::now(); + let expected_approx = + now + tokio::time::Duration::from_secs(3600 - TOKEN_EXPIRY_BUFFER_SECS); + + // Allow some tolerance for timing + let tolerance = tokio::time::Duration::from_secs(2); + assert!(next_refresh >= expected_approx - tolerance); + assert!(next_refresh <= expected_approx + tolerance); + } +} diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/mod.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/mod.rs new file mode 100644 index 0000000000..05cfb6454d --- /dev/null +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_identity_auth_extension/mod.rs @@ -0,0 +1,102 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +//! Azure Identity Auth Extension for OTAP. +//! +//! Provides Azure authentication services to the pipeline using Azure Identity. +//! This extension manages token acquisition and refresh, making credentials +//! available to other components (e.g., exporters) that need Azure authentication. +//! +//! # Features +//! +//! - Managed Identity authentication (system or user-assigned) +//! - Developer tools authentication (Azure CLI, Azure Developer CLI) +//! - Automatic token refresh with exponential backoff retry +//! - Shared credential access across pipeline components + +use linkme::distributed_slice; +use otap_df_config::node::NodeUserConfig; +use otap_df_engine::config::ExtensionConfig; +use otap_df_engine::context::PipelineContext; +use otap_df_engine::extension::ExtensionWrapper; +use otap_df_engine::extensions::BearerTokenProvider; +use otap_df_engine::node::NodeId; +use otap_df_engine::{ExtensionFactory, extension_traits}; +use serde_json; +use std::sync::Arc; + +use crate::OTAP_EXTENSION_FACTORIES; +use crate::pdata::OtapPdata; + +mod config; +mod error; +mod extension; + +pub use config::{AuthMethod, Config}; +pub use error::Error; +pub use extension::AzureIdentityAuthExtension; + +/// URN identifying the Azure Identity Auth Extension in configuration pipelines. +pub const AZURE_IDENTITY_AUTH_EXTENSION_URN: &str = "urn:otel:azureidentityauth:extension"; + +/// Register Azure Identity Auth Extension with the OTAP extension factory. +/// +/// Uses the `distributed_slice` macro for automatic discovery by the dataflow engine. +#[allow(unsafe_code)] +#[distributed_slice(OTAP_EXTENSION_FACTORIES)] +pub static AZURE_IDENTITY_AUTH_EXTENSION: ExtensionFactory = ExtensionFactory { + name: AZURE_IDENTITY_AUTH_EXTENSION_URN, + create: |_: PipelineContext, + node: NodeId, + node_config: Arc, + extension_config: &ExtensionConfig| { + // Deserialize user config JSON into typed Config + let cfg: Config = serde_json::from_value(node_config.config.clone()).map_err(|e| { + otap_df_config::error::Error::InvalidUserConfig { + error: e.to_string(), + } + })?; + + // Validate the configuration + cfg.validate() + .map_err(|e| otap_df_config::error::Error::InvalidUserConfig { + error: e.to_string(), + })?; + + // Create the extension + let extension = AzureIdentityAuthExtension::new(cfg).map_err(|e| { + otap_df_config::error::Error::InvalidUserConfig { + error: e.to_string(), + } + })?; + + Ok(ExtensionWrapper::local( + extension, + extension_traits!(AzureIdentityAuthExtension => BearerTokenProvider), + node, + node_config, + extension_config, + )) + }, +}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extension_urn() { + assert_eq!( + AZURE_IDENTITY_AUTH_EXTENSION_URN, + "urn:otel:azureidentityauth:extension" + ); + } + + #[test] + fn test_factory_name_matches_urn() { + assert_eq!( + AZURE_IDENTITY_AUTH_EXTENSION.name, + AZURE_IDENTITY_AUTH_EXTENSION_URN + ); + } +} diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/auth.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/auth.rs deleted file mode 100644 index e08fbebe10..0000000000 --- a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/auth.rs +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -use azure_core::credentials::{AccessToken, TokenCredential}; -use azure_identity::{ - DeveloperToolsCredential, DeveloperToolsCredentialOptions, ManagedIdentityCredential, - ManagedIdentityCredentialOptions, UserAssignedId, -}; -use std::sync::Arc; - -use super::Error; -use super::config::{AuthConfig, AuthMethod}; - -/// Minimum delay between token refresh retry attempts in seconds. -const MIN_RETRY_DELAY_SECS: f64 = 5.0; -/// Maximum delay between token refresh retry attempts in seconds. -const MAX_RETRY_DELAY_SECS: f64 = 30.0; -/// Maximum jitter percentage (±10%) to add to retry delays. -const MAX_RETRY_JITTER_RATIO: f64 = 0.10; - -#[derive(Clone, Debug)] -// TODO - Consolidate with crates/otap/src/{cloud_auth,object_store)/azure.rs -#[allow(clippy::print_stdout)] -pub struct Auth { - credential: Arc, - scope: String, -} - -// TODO: Remove print_stdout after logging is set up -#[allow(clippy::print_stdout)] -impl Auth { - pub fn new(auth_config: &AuthConfig) -> Result { - let credential = Self::create_credential(auth_config)?; - - Ok(Self { - credential, - scope: auth_config.scope.clone(), - }) - } - - #[cfg(test)] - pub fn from_credential(credential: Arc, scope: String) -> Self { - Self { credential, scope } - } - - async fn get_token_internal(&self) -> Result { - let token_response = self - .credential - .get_token( - &[&self.scope], - Some(azure_core::credentials::TokenRequestOptions::default()), - ) - .await - .map_err(Error::token_acquisition)?; - - Ok(token_response) - } - - pub async fn get_token(&mut self) -> Result { - let mut attempt = 0_i32; - loop { - attempt += 1; - - match self.get_token_internal().await { - Ok(token) => { - println!( - "[AzureMonitorExporter] Obtained access token, expires on {}", - token.expires_on - ); - return Ok(token); - } - Err(e) => { - println!( - "[AzureMonitorExporter] Failed to obtain access token (attempt {}): {e}", - attempt - ); - } - } - - // Calculate exponential backoff: 5s, 10s, 20s, 30s (capped) - let base_delay_secs = MIN_RETRY_DELAY_SECS * 2.0_f64.powi(attempt - 1); - let capped_delay_secs = base_delay_secs.min(MAX_RETRY_DELAY_SECS); - - // Add jitter: random value between -10% and +10% of the delay - let jitter_range = capped_delay_secs * MAX_RETRY_JITTER_RATIO; - let jitter = if jitter_range > 0.0 { - let random_factor = rand::random::() * 2.0 - 1.0; - random_factor * jitter_range - } else { - 0.0 - }; - - let delay_secs = (capped_delay_secs + jitter).max(1.0); - let delay = tokio::time::Duration::from_secs_f64(delay_secs); - - println!("[AzureMonitorExporter] Retrying in {:.1}s...", delay_secs); - tokio::time::sleep(delay).await; - } - } - - fn create_credential(auth_config: &AuthConfig) -> Result, Error> { - match auth_config.method { - AuthMethod::ManagedIdentity => { - let mut options = ManagedIdentityCredentialOptions::default(); - - if let Some(client_id) = &auth_config.client_id { - println!("Using user-assigned managed identity with client_id: {client_id}"); - options.user_assigned_id = Some(UserAssignedId::ClientId(client_id.clone())); - } else { - println!("Using system-assigned managed identity"); - } - - Ok(ManagedIdentityCredential::new(Some(options)) - .map_err(|e| Error::create_credential(AuthMethod::ManagedIdentity, e))?) - } - AuthMethod::Development => { - println!("Using developer tools credential (Azure CLI / Azure Developer CLI)"); - Ok( - DeveloperToolsCredential::new(Some(DeveloperToolsCredentialOptions::default())) - .map_err(|e| Error::create_credential(AuthMethod::Development, e))?, - ) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use azure_core::credentials::TokenRequestOptions; - use azure_core::time::OffsetDateTime; - use std::sync::atomic::{AtomicUsize, Ordering}; - - #[derive(Debug)] - struct MockCredential { - token: String, - expires_in: azure_core::time::Duration, - call_count: Arc, - } - - fn make_mock_credential( - token: &str, - expires_in: azure_core::time::Duration, - call_count: Arc, - ) -> Arc { - let cred: Arc = Arc::new(MockCredential { - token: token.to_string(), - expires_in, - call_count, - }); - cred - } - - #[async_trait::async_trait] - impl TokenCredential for MockCredential { - async fn get_token( - &self, - _scopes: &[&str], - _options: Option>, - ) -> azure_core::Result { - let _ = self.call_count.fetch_add(1, Ordering::SeqCst); - - Ok(AccessToken { - token: self.token.clone().into(), - expires_on: OffsetDateTime::now_utc() + self.expires_in, - }) - } - } - - // ==================== Construction Tests ==================== - - #[tokio::test] - async fn test_from_credential_creates_auth() { - let credential = make_mock_credential( - "test_token", - azure_core::time::Duration::minutes(60), - Arc::new(AtomicUsize::new(0)), - ); - - let auth = Auth::from_credential(credential, "test_scope".to_string()); - assert_eq!(auth.scope, "test_scope"); - } - - #[tokio::test] - async fn test_new_with_managed_identity_user_assigned() { - let auth_config = AuthConfig { - method: AuthMethod::ManagedIdentity, - client_id: Some("test-client-id".to_string()), - scope: "https://test.scope".to_string(), - }; - - let auth = Auth::new(&auth_config); - assert!(auth.is_ok()); - let auth = auth.unwrap(); - assert_eq!(auth.scope, "https://test.scope"); - } - - #[tokio::test] - async fn test_new_with_managed_identity_system_assigned() { - let auth_config = AuthConfig { - method: AuthMethod::ManagedIdentity, - client_id: None, - scope: "https://test.scope".to_string(), - }; - - let auth = Auth::new(&auth_config); - assert!(auth.is_ok()); - } - - #[tokio::test] - async fn test_new_with_development_auth() { - let auth_config = AuthConfig { - method: AuthMethod::Development, - client_id: None, - scope: "https://test.scope".to_string(), - }; - - // May fail if Azure CLI not installed - both outcomes are valid - let result = Auth::new(&auth_config); - match result { - Ok(auth) => assert_eq!(auth.scope, "https://test.scope"), - Err(Error::Auth { - kind: super::super::error::AuthErrorKind::CreateCredential { method }, - .. - }) => { - assert_eq!(method, AuthMethod::Development); - } - Err(err) => panic!("Unexpected error type: {:?}", err), - } - } - - // ==================== Token Fetching Tests ==================== - - #[tokio::test] - async fn test_get_token_internal_returns_valid_token() { - let call_count = Arc::new(AtomicUsize::new(0)); - let credential = make_mock_credential( - "test_token", - azure_core::time::Duration::minutes(60), - call_count.clone(), - ); - - let auth = Auth::from_credential(credential, "scope".to_string()); - - let token = auth.get_token_internal().await.unwrap(); - assert_eq!(token.token.secret(), "test_token"); - assert_eq!(call_count.load(Ordering::SeqCst), 1); - } - - #[tokio::test] - async fn test_get_token_internal_calls_credential_each_time() { - let call_count = Arc::new(AtomicUsize::new(0)); - let credential = make_mock_credential( - "test_token", - azure_core::time::Duration::minutes(60), - call_count.clone(), - ); - - let auth = Auth::from_credential(credential, "scope".to_string()); - - // Each call to get_token_internal should call the credential - let _ = auth.get_token_internal().await.unwrap(); - assert_eq!(call_count.load(Ordering::SeqCst), 1); - - let _ = auth.get_token_internal().await.unwrap(); - assert_eq!(call_count.load(Ordering::SeqCst), 2); - - let _ = auth.get_token_internal().await.unwrap(); - assert_eq!(call_count.load(Ordering::SeqCst), 3); - } - - #[tokio::test] - async fn test_get_token_internal_returns_cloned_tokens() { - let credential = make_mock_credential( - "test_token", - azure_core::time::Duration::minutes(60), - Arc::new(AtomicUsize::new(0)), - ); - - let auth = Auth::from_credential(credential, "scope".to_string()); - - let token1 = auth.get_token_internal().await.unwrap(); - let token2 = auth.get_token_internal().await.unwrap(); - - // Same value from both calls - assert_eq!(token1.token.secret(), token2.token.secret()); - } - - // ==================== Error Handling Tests ==================== - - #[tokio::test] - async fn test_get_token_internal_propagates_credential_error() { - #[derive(Debug)] - struct FailingCredential; - - #[async_trait::async_trait] - impl TokenCredential for FailingCredential { - async fn get_token( - &self, - _scopes: &[&str], - _options: Option>, - ) -> azure_core::Result { - Err(azure_core::error::Error::new( - azure_core::error::ErrorKind::Credential, - "Mock credential failure", - )) - } - } - - let cred = FailingCredential; - let credential: Arc = Arc::new(cred); - let auth = Auth::from_credential(credential, "scope".to_string()); - - let result = auth.get_token_internal().await; - assert!(result.is_err()); - match result.unwrap_err() { - Error::Auth { - kind: super::super::error::AuthErrorKind::TokenAcquisition, - .. - } => {} - err => panic!("Expected Auth token acquisition error, got: {:?}", err), - } - } - - // ==================== Clone Behavior Tests ==================== - - #[tokio::test] - async fn test_cloned_auth_shares_credential() { - let call_count = Arc::new(AtomicUsize::new(0)); - let credential = make_mock_credential( - "test_token", - azure_core::time::Duration::minutes(60), - call_count.clone(), - ); - - let auth1 = Auth::from_credential(credential, "scope".to_string()); - let auth2 = auth1.clone(); - - // Both auth instances share the same credential - let _ = auth1.get_token_internal().await.unwrap(); - assert_eq!(call_count.load(Ordering::SeqCst), 1); - - let _ = auth2.get_token_internal().await.unwrap(); - assert_eq!(call_count.load(Ordering::SeqCst), 2); - } -} diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/config.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/config.rs index 068621bd47..ff59ba3fdd 100644 --- a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/config.rs +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/config.rs @@ -13,54 +13,10 @@ pub struct Config { /// API configuration for Azure Monitor pub api: ApiConfig, - /// Authentication configuration + /// Name of the authentication extension to use for token acquisition. + /// This should match the name of an Azure Identity Auth Extension configured in the pipeline. #[serde(default)] - pub auth: AuthConfig, -} - -/// Authentication method for Azure -#[derive(Debug, Deserialize, Clone, PartialEq, Default)] -#[serde(rename_all = "lowercase")] -pub enum AuthMethod { - /// Use Managed Identity (system or user-assigned with client_id) - #[serde(alias = "msi", alias = "managed_identity")] - #[default] - ManagedIdentity, - - /// Use developer tools (Azure CLI, Azure Developer CLI) - #[serde(alias = "dev", alias = "developer", alias = "cli")] - Development, -} - -/// Authentication configuration for Azure -#[derive(Debug, Deserialize, Clone)] -pub struct AuthConfig { - /// Authentication method to use - #[serde(default)] - pub method: AuthMethod, - - /// Client ID for user-assigned managed identity (optional) - /// Only used when method is ManagedIdentity - /// If not provided with ManagedIdentity, system-assigned identity will be used - pub client_id: Option, - - /// OAuth scope for token acquisition (defaults to "https://monitor.azure.com/.default") - #[serde(default = "default_scope")] - pub scope: String, -} - -impl Default for AuthConfig { - fn default() -> Self { - Self { - method: AuthMethod::default(), - client_id: None, - scope: default_scope(), - } - } -} - -fn default_scope() -> String { - "https://monitor.azure.com/.default".to_string() + pub auth: String, } /// API configuration for connecting to Azure Monitor @@ -99,13 +55,6 @@ pub struct SchemaConfig { impl Config { /// Validate the configuration pub fn validate(&self) -> Result<(), Error> { - // Validate auth configuration - if self.auth.scope.is_empty() { - return Err(Error::Config( - "Invalid configuration: auth scope must be non-empty".to_string(), - )); - } - // Validate API configuration if self.api.dcr_endpoint.is_empty() { return Err(Error::Config( @@ -194,11 +143,7 @@ mod tests { dcr: "mydcr".to_string(), schema: SchemaConfig::default(), }, - auth: AuthConfig { - scope: "https://monitor.azure.com/.default".to_string(), - client_id: Some("myclientid".to_string()), - method: AuthMethod::ManagedIdentity, - }, + auth: "azure_identity_auth".to_string(), }; assert!(config.validate().is_ok()); @@ -213,7 +158,7 @@ mod tests { dcr: "".to_string(), schema: SchemaConfig::default(), }, - auth: AuthConfig::default(), + auth: String::new(), }; let result = config.validate(); @@ -241,7 +186,7 @@ mod tests { ]), }, }, - auth: AuthConfig::default(), + auth: String::new(), }; let result = config.validate(); @@ -281,7 +226,7 @@ mod tests { ]), }, }, - auth: AuthConfig::default(), + auth: String::new(), }; let result = config.validate(); @@ -310,7 +255,7 @@ mod tests { )]), }, }, - auth: AuthConfig::default(), + auth: String::new(), }; let result = config.validate(); diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/error.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/error.rs index f454b61adc..c1285b3145 100644 --- a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/error.rs +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/error.rs @@ -1,7 +1,6 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -use super::config::AuthMethod; use http::StatusCode; use http::header::InvalidHeaderValue; @@ -124,7 +123,7 @@ pub enum Error { /// Failed to create auth handler. #[error("Failed to create auth handler")] - AuthHandlerCreation(#[source] Box), + AuthHandlerCreation(#[source] Box), /// Client pool initialization failed. #[error("Client pool initialization failed")] @@ -149,12 +148,6 @@ pub enum Error { /// Authentication error classification. #[derive(Debug, Clone)] pub enum AuthErrorKind { - /// Failed to create credential (during setup). - CreateCredential { method: AuthMethod }, - /// Failed to acquire token. - TokenAcquisition, - /// Token refresh failed during retry. - TokenRefresh, /// Server returned 401. Unauthorized, /// Server returned 403. @@ -164,9 +157,6 @@ pub enum AuthErrorKind { impl std::fmt::Display for AuthErrorKind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::CreateCredential { method } => write!(f, "create credential: {method:?}"), - Self::TokenAcquisition => write!(f, "token acquisition"), - Self::TokenRefresh => write!(f, "token refresh"), Self::Unauthorized => write!(f, "unauthorized"), Self::Forbidden => write!(f, "forbidden"), } @@ -242,26 +232,6 @@ impl Error { Self::Network { kind, source } } - /// Creates a credential creation error. - #[must_use] - pub fn create_credential(method: AuthMethod, source: azure_core::error::Error) -> Self { - Self::Auth { - kind: AuthErrorKind::CreateCredential { method }, - source: Some(source), - body: None, - } - } - - /// Creates a token acquisition error. - #[must_use] - pub fn token_acquisition(source: azure_core::error::Error) -> Self { - Self::Auth { - kind: AuthErrorKind::TokenAcquisition, - source: Some(source), - body: None, - } - } - /// Creates an unauthorized (401) error. #[must_use] pub fn unauthorized(body: String) -> Self { @@ -328,31 +298,6 @@ mod tests { // ==================== Auth Error Tests ==================== - #[test] - fn test_auth_create_credential_message() { - let azure_error = azure_core::error::Error::with_message( - azure_core::error::ErrorKind::Credential, - "managed identity not available", - ); - let error = Error::create_credential(AuthMethod::ManagedIdentity, azure_error); - assert_eq!( - error.to_string(), - "Auth error (create credential: ManagedIdentity)" - ); - assert!(error.source().is_some()); - } - - #[test] - fn test_auth_token_acquisition_message() { - let azure_error = azure_core::error::Error::with_message( - azure_core::error::ErrorKind::Credential, - "token expired", - ); - let error = Error::token_acquisition(azure_error); - assert_eq!(error.to_string(), "Auth error (token acquisition)"); - assert!(error.source().is_some()); - } - #[test] fn test_auth_unauthorized_message() { let error = Error::unauthorized("invalid token".to_string()); @@ -456,15 +401,28 @@ mod tests { } .is_retryable() ); - assert!( - !Error::token_acquisition(azure_core::error::Error::with_message( - azure_core::error::ErrorKind::Credential, - "test" - )) - .is_retryable() + } + + #[test] + fn test_retry_after_server_error() { + let error = Error::ServerError { + status: StatusCode::SERVICE_UNAVAILABLE, + body: String::new(), + retry_after: Some(std::time::Duration::from_secs(60)), + }; + assert_eq!( + error.retry_after(), + Some(std::time::Duration::from_secs(60)) ); } + #[test] + fn test_retry_after_returns_none_for_other_errors() { + assert!(Error::PayloadTooLarge.retry_after().is_none()); + assert!(Error::unauthorized(String::new()).retry_after().is_none()); + assert!(Error::Config("test".to_string()).retry_after().is_none()); + } + // ==================== Display Tests ==================== #[test] @@ -478,18 +436,6 @@ mod tests { #[test] fn test_auth_error_kind_display() { - assert_eq!( - AuthErrorKind::CreateCredential { - method: AuthMethod::ManagedIdentity - } - .to_string(), - "create credential: ManagedIdentity" - ); - assert_eq!( - AuthErrorKind::TokenAcquisition.to_string(), - "token acquisition" - ); - assert_eq!(AuthErrorKind::TokenRefresh.to_string(), "token refresh"); assert_eq!(AuthErrorKind::Unauthorized.to_string(), "unauthorized"); assert_eq!(AuthErrorKind::Forbidden.to_string(), "forbidden"); } diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/exporter.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/exporter.rs index 63f624ae97..c598ed86c4 100644 --- a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/exporter.rs +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/exporter.rs @@ -2,12 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 use async_trait::async_trait; -use azure_core::credentials::AccessToken; use otap_df_channel::error::RecvError; use otap_df_config::SignalType; use otap_df_engine::ConsumerEffectHandlerExtension; use otap_df_engine::control::{AckMsg, NackMsg, NodeControlMsg}; use otap_df_engine::error::Error as EngineError; +use otap_df_engine::extensions::BearerTokenProvider; use otap_df_engine::local::exporter::{EffectHandler, Exporter}; use otap_df_engine::message::{Message, MessageChannel}; use otap_df_engine::terminal_state::TerminalState; @@ -17,7 +17,6 @@ use otap_df_pdata::views::otap::OtapLogsView; use otap_df_pdata::views::otlp::bytes::logs::RawLogsData; use otap_df_pdata::{OtapArrowRecords, OtapPayload}; -use super::auth::Auth; use super::client::LogsIngestionClientPool; use super::config::Config; use super::error::Error; @@ -35,12 +34,6 @@ const MAX_IN_FLIGHT_EXPORTS: usize = 16; const PERIODIC_EXPORT_INTERVAL: u64 = 3; const STATS_PRINT_INTERVAL: u64 = 3; const HEARTBEAT_INTERVAL_SECONDS: u64 = 60; -/// Minimum interval between token refresh attempts (10 seconds). -const MIN_TOKEN_REFRESH_INTERVAL_SECS: u64 = 10; -/// Buffer time before token expiry to trigger a refresh. -/// Azure Identity SDK caches tokens internally and won't issue a new token -/// until ~5 minutes before expiry, so we schedule refresh at 295 seconds before expiry. -const TOKEN_EXPIRY_BUFFER_SECS: u64 = 295; /// Azure Monitor exporter. pub struct AzureMonitorExporter { @@ -312,25 +305,6 @@ impl AzureMonitorExporter { Ok(()) } - #[inline] - fn get_next_token_refresh(token: AccessToken) -> tokio::time::Instant { - let now = azure_core::time::OffsetDateTime::now_utc(); - let duration_remaining = if token.expires_on > now { - (token.expires_on - now).unsigned_abs() - } else { - std::time::Duration::ZERO - }; - - let token_valid_until = tokio::time::Instant::now() + duration_remaining; - let next_token_refresh = - token_valid_until - tokio::time::Duration::from_secs(TOKEN_EXPIRY_BUFFER_SECS); - std::cmp::max( - next_token_refresh, - tokio::time::Instant::now() - + tokio::time::Duration::from_secs(MIN_TOKEN_REFRESH_INTERVAL_SECS), - ) - } - async fn handle_message( &mut self, effect_handler: &EffectHandler, @@ -424,12 +398,14 @@ impl Exporter for AzureMonitorExporter { let mut msg_id = 0; - let mut auth = Auth::new(&self.config.auth).map_err(|e| { - let error = Error::AuthHandlerCreation(Box::new(e)); - EngineError::InternalError { - message: error.to_string(), - } - })?; + let auth = effect_handler + .get_extension::(self.config.auth.as_str()) + .map_err(|e| { + let error = Error::AuthHandlerCreation(Box::new(e)); + EngineError::InternalError { + message: error.to_string(), + } + })?; self.client_pool .initialize(&self.config.api) @@ -441,7 +417,30 @@ impl Exporter for AzureMonitorExporter { } })?; - let mut next_token_refresh = tokio::time::Instant::now(); + // Subscribe to token refresh events from the auth extension + let mut token_rx = auth.subscribe_token_refresh(); + + // Wait for the initial token - blocks until the auth extension provides one + println!("[AzureMonitorExporter] Waiting for initial auth token..."); + let _ = + token_rx + .wait_for(|t| t.is_some()) + .await + .map_err(|_| EngineError::InternalError { + message: "Auth extension closed before providing a token".to_string(), + })?; + + // Now we're guaranteed to have a token + if let Some(token) = token_rx.borrow().as_ref() { + let header = HeaderValue::from_str(&format!("Bearer {}", token.token.secret())) + .map_err(|e| EngineError::InternalError { + message: format!("Failed to create auth header: {:?}", e), + })?; + self.client_pool.update_auth(header.clone()); + self.heartbeat.update_auth(header); + println!("[AzureMonitorExporter] Initial auth token set"); + } + let mut next_stats_print = tokio::time::Instant::now() + tokio::time::Duration::from_secs(STATS_PRINT_INTERVAL); let mut next_periodic_export = tokio::time::Instant::now() @@ -455,37 +454,18 @@ impl Exporter for AzureMonitorExporter { tokio::select! { biased; - _ = tokio::time::sleep_until(next_token_refresh) => { - match auth.get_token().await { - Ok(access_token) => { - match HeaderValue::from_str(&format!("Bearer {}", access_token.token.secret())) { - Ok(header) => { - self.client_pool.update_auth(header.clone()); - self.heartbeat.update_auth(header.clone()); - - // Schedule next token refresh - next_token_refresh = Self::get_next_token_refresh(access_token); - - let refresh_in = next_token_refresh.saturating_duration_since(tokio::time::Instant::now()); - let total_secs = refresh_in.as_secs(); - let hours = total_secs / 3600; - let minutes = (total_secs % 3600) / 60; - let seconds = total_secs % 60; - - println!("[AzureMonitorExporter] Access token refreshed, next refresh scheduled in {}h {}m {}s", hours, minutes, seconds); - } - Err(e) => { - println!("[AzureMonitorExporter] Failed to create auth header: {:?}", e); - // Retry every 10 seconds - next_token_refresh = tokio::time::Instant::now() + tokio::time::Duration::from_secs(10); - } + // React to token refresh events from the auth extension + _ = token_rx.changed() => { + if let Some(token) = token_rx.borrow_and_update().as_ref() { + match HeaderValue::from_str(&format!("Bearer {}", token.token.secret())) { + Ok(header) => { + self.client_pool.update_auth(header.clone()); + self.heartbeat.update_auth(header); + println!("[AzureMonitorExporter] Auth token refreshed"); + } + Err(e) => { + println!("[AzureMonitorExporter] Failed to create auth header: {:?}", e); } - - } - Err(e) => { - println!("[AzureMonitorExporter] Failed to refresh access token: {:?}", e); - // Retry every 10 seconds - next_token_refresh = tokio::time::Instant::now() + tokio::time::Duration::from_secs(10); } } } @@ -607,10 +587,9 @@ exports | in_flight={} stats_time={:?} #[cfg(test)] mod tests { - use super::super::config::{ApiConfig, AuthConfig, SchemaConfig}; + use super::super::config::{ApiConfig, SchemaConfig}; use super::*; use crate::pdata::Context; - use azure_core::time::OffsetDateTime; use bytes::Bytes; use http::StatusCode; use otap_df_engine::local::exporter::EffectHandler; @@ -630,7 +609,7 @@ mod tests { log_record_mapping: HashMap::new(), }, }, - auth: AuthConfig::default(), + auth: "azure_identity_auth".to_string(), } } @@ -640,31 +619,6 @@ mod tests { let _ = AzureMonitorExporter::new(config).unwrap(); } - #[test] - fn test_get_next_token_refresh_logic() { - let now = OffsetDateTime::now_utc(); - let expires_on = now + azure_core::time::Duration::seconds(3600); - - let token = AccessToken { - token: "secret".into(), - expires_on, - }; - - let refresh_at = AzureMonitorExporter::get_next_token_refresh(token); - let duration_until_refresh = refresh_at.duration_since(tokio::time::Instant::now()); - - // Should be 3600 - 295 = 3305 seconds before refresh - // Allow some delta for execution time - let expected = 3305.0; - let actual = duration_until_refresh.as_secs_f64(); - assert!( - (actual - expected).abs() < 5.0, - "Expected ~{}, got {}", - expected, - actual - ); - } - #[tokio::test] async fn test_handle_export_success() { let config = create_test_config(); diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/mod.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/mod.rs index 6aa482edbf..a385ddc876 100644 --- a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/mod.rs +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/mod.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use crate::OTAP_EXPORTER_FACTORIES; use crate::pdata::OtapPdata; -mod auth; mod client; mod config; mod error; diff --git a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/transformer.rs b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/transformer.rs index 726b7fac4d..fdd6d1e034 100644 --- a/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/transformer.rs +++ b/rust/otap-dataflow/crates/otap/src/experimental/azure_monitor_exporter/transformer.rs @@ -387,7 +387,7 @@ mod tests { use std::collections::HashMap; fn create_test_config() -> Config { - use super::super::config::{ApiConfig, AuthConfig, SchemaConfig}; + use super::super::config::{ApiConfig, SchemaConfig}; Config { api: ApiConfig { @@ -407,7 +407,7 @@ mod tests { ]), }, }, - auth: AuthConfig::default(), + auth: "azure_identity_auth".to_string(), } } @@ -695,7 +695,7 @@ mod tests { #[test] fn test_empty_schema_mappings() { - use super::super::config::{ApiConfig, AuthConfig, SchemaConfig}; + use super::super::config::{ApiConfig, SchemaConfig}; let config = Config { api: ApiConfig { @@ -708,7 +708,7 @@ mod tests { log_record_mapping: HashMap::new(), }, }, - auth: AuthConfig::default(), + auth: "azure_identity_auth".to_string(), }; let transformer = Transformer::new(&config); diff --git a/rust/otap-dataflow/crates/otap/src/experimental/mod.rs b/rust/otap-dataflow/crates/otap/src/experimental/mod.rs index e6224643c2..f75df9598f 100644 --- a/rust/otap-dataflow/crates/otap/src/experimental/mod.rs +++ b/rust/otap-dataflow/crates/otap/src/experimental/mod.rs @@ -1,9 +1,9 @@ // Copyright The OpenTelemetry Authors // SPDX-License-Identifier: Apache-2.0 -//! Experimental exporters +//! Experimental exporters, processors, and extensions //! -//! This module contains exporters that are not fully supported +//! This module contains components that are not fully supported //! but are related to project goals mentioned //! in [OTel-Arrow Project Phases](../../../../../../docs/project-phases.md). @@ -15,6 +15,10 @@ pub mod geneva_exporter; #[cfg(feature = "azure-monitor-exporter")] pub mod azure_monitor_exporter; +/// Azure Identity Auth Extension for Azure authentication +#[cfg(feature = "azure-identity-auth-extension")] +pub mod azure_identity_auth_extension; + /// Condense Attributes processor #[cfg(feature = "condense-attributes-processor")] pub mod condense_attributes_processor; diff --git a/rust/otap-dataflow/crates/otap/src/lib.rs b/rust/otap-dataflow/crates/otap/src/lib.rs index 93dc5b4617..d94c21316f 100644 --- a/rust/otap-dataflow/crates/otap/src/lib.rs +++ b/rust/otap-dataflow/crates/otap/src/lib.rs @@ -52,10 +52,11 @@ pub mod noop_exporter; /// An error-exporter returns a static error. pub mod error_exporter; -/// Experimental exporters and processors +/// Experimental exporters, processors, and extensions #[cfg(any( feature = "experimental-exporters", - feature = "experimental-processors" + feature = "experimental-processors", + feature = "experimental-extensions" ))] pub mod experimental;