diff --git a/crates/client/src/async_activity_handle.rs b/crates/client/src/async_activity_handle.rs index 9595060ff..ce32c4abb 100644 --- a/crates/client/src/async_activity_handle.rs +++ b/crates/client/src/async_activity_handle.rs @@ -17,6 +17,18 @@ use temporalio_common::protos::{ }; use tonic::IntoRequest; +/// Generate resource_id for routing based on workflow_id and activity_id. +/// Uses "workflow:workflow_id" when workflow_id is not empty, otherwise "activity:activity_id". +fn generate_resource_id(workflow_id: &str, activity_id: &str) -> String { + if !workflow_id.is_empty() { + format!("workflow:{}", workflow_id) + } else if !activity_id.is_empty() { + format!("activity:{}", activity_id) + } else { + String::new() + } +} + /// Identifies an async activity for completion outside a worker. #[derive(Debug, Clone)] pub enum ActivityIdentifier { @@ -110,7 +122,7 @@ impl AsyncActivityHandle { activity_id: activity_id.clone(), result, identity: self.client.identity(), - resource_id: Default::default(), + resource_id: generate_resource_id(workflow_id, activity_id), } .into_request(), ) @@ -159,7 +171,7 @@ impl AsyncActivityHandle { failure: Some(failure), identity: self.client.identity(), last_heartbeat_details, - resource_id: Default::default(), + resource_id: generate_resource_id(workflow_id, activity_id), } .into_request(), ) @@ -205,6 +217,7 @@ impl AsyncActivityHandle { activity_id: activity_id.clone(), details, identity: self.client.identity(), + resource_id: generate_resource_id(workflow_id, activity_id), ..Default::default() } .into_request(), @@ -233,7 +246,7 @@ impl AsyncActivityHandle { details, identity: self.client.identity(), namespace: self.client.namespace(), - resource_id: Default::default(), + resource_id: String::new(), // Resource id is unavailable for task token } .into_request(), ) @@ -256,7 +269,7 @@ impl AsyncActivityHandle { activity_id: activity_id.clone(), details, identity: self.client.identity(), - resource_id: Default::default(), + resource_id: generate_resource_id(workflow_id, activity_id), } .into_request(), ) diff --git a/crates/client/src/grpc.rs b/crates/client/src/grpc.rs index f67524563..295d7d996 100644 --- a/crates/client/src/grpc.rs +++ b/crates/client/src/grpc.rs @@ -566,10 +566,13 @@ macro_rules! proxier { }; } -macro_rules! namespaced_request { +macro_rules! request_with_headers { ($req:ident) => {{ + use temporalio_common::request_headers::extract_temporal_request_headers; + let ns_str = $req.get_ref().namespace.clone(); - // Attach namespace header + + // Attach namespace header (existing behavior) $req.metadata_mut().insert( TEMPORAL_NAMESPACE_HEADER_KEY, ns_str.parse().unwrap_or_else(|e| { @@ -577,6 +580,27 @@ macro_rules! namespaced_request { AsciiMetadataValue::from_static("") }), ); + + // Extract and attach additional headers from proto annotations + let headers = extract_temporal_request_headers( + $req.get_ref() as &dyn std::any::Any, + Some($req.metadata()), + ); + + for (key, value) in headers { + if let Ok(header_key) = + key.parse::>() + { + if let Ok(header_value) = value.parse::() { + $req.metadata_mut().insert(header_key, header_value); + } else { + warn!("Unable to parse header value for {}: {}", key, value); + } + } else { + warn!("Unable to parse header key: {}", key); + } + } + // Init metric labels AttachMetricLabels::namespace(ns_str) }}; @@ -598,7 +622,7 @@ proxier! { RegisterNamespaceRequest, RegisterNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -607,7 +631,7 @@ proxier! { DescribeNamespaceRequest, DescribeNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -621,7 +645,7 @@ proxier! { UpdateNamespaceRequest, UpdateNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -630,7 +654,7 @@ proxier! { DeprecateNamespaceRequest, DeprecateNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -639,7 +663,7 @@ proxier! { StartWorkflowExecutionRequest, StartWorkflowExecutionResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); }, @@ -691,7 +715,7 @@ proxier! { GetWorkflowExecutionHistoryRequest, GetWorkflowExecutionHistoryResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); if r.get_ref().wait_new_event { r.extensions_mut().insert(IsUserLongPoll); @@ -703,7 +727,7 @@ proxier! { GetWorkflowExecutionHistoryReverseRequest, GetWorkflowExecutionHistoryReverseResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -712,7 +736,7 @@ proxier! { PollWorkflowTaskQueueRequest, PollWorkflowTaskQueueResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -722,7 +746,7 @@ proxier! { RespondWorkflowTaskCompletedRequest, RespondWorkflowTaskCompletedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -731,7 +755,7 @@ proxier! { RespondWorkflowTaskFailedRequest, RespondWorkflowTaskFailedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -740,7 +764,7 @@ proxier! { PollActivityTaskQueueRequest, PollActivityTaskQueueResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -750,7 +774,7 @@ proxier! { RecordActivityTaskHeartbeatRequest, RecordActivityTaskHeartbeatResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -759,7 +783,7 @@ proxier! { RecordActivityTaskHeartbeatByIdRequest, RecordActivityTaskHeartbeatByIdResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -768,7 +792,7 @@ proxier! { RespondActivityTaskCompletedRequest, RespondActivityTaskCompletedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -777,7 +801,7 @@ proxier! { RespondActivityTaskCompletedByIdRequest, RespondActivityTaskCompletedByIdResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -787,7 +811,7 @@ proxier! { RespondActivityTaskFailedRequest, RespondActivityTaskFailedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -796,7 +820,7 @@ proxier! { RespondActivityTaskFailedByIdRequest, RespondActivityTaskFailedByIdResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -805,7 +829,7 @@ proxier! { RespondActivityTaskCanceledRequest, RespondActivityTaskCanceledResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -814,7 +838,7 @@ proxier! { RespondActivityTaskCanceledByIdRequest, RespondActivityTaskCanceledByIdResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -823,7 +847,7 @@ proxier! { RequestCancelWorkflowExecutionRequest, RequestCancelWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -832,7 +856,7 @@ proxier! { SignalWorkflowExecutionRequest, SignalWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -841,7 +865,7 @@ proxier! { SignalWithStartWorkflowExecutionRequest, SignalWithStartWorkflowExecutionResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -851,7 +875,7 @@ proxier! { ResetWorkflowExecutionRequest, ResetWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -860,7 +884,7 @@ proxier! { TerminateWorkflowExecutionRequest, TerminateWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -869,7 +893,7 @@ proxier! { DeleteWorkflowExecutionRequest, DeleteWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -878,7 +902,7 @@ proxier! { ListOpenWorkflowExecutionsRequest, ListOpenWorkflowExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -887,7 +911,7 @@ proxier! { ListClosedWorkflowExecutionsRequest, ListClosedWorkflowExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -896,7 +920,7 @@ proxier! { ListWorkflowExecutionsRequest, ListWorkflowExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -905,7 +929,7 @@ proxier! { ListArchivedWorkflowExecutionsRequest, ListArchivedWorkflowExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -914,7 +938,7 @@ proxier! { ScanWorkflowExecutionsRequest, ScanWorkflowExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -923,7 +947,7 @@ proxier! { CountWorkflowExecutionsRequest, CountWorkflowExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -932,7 +956,7 @@ proxier! { CreateWorkflowRuleRequest, CreateWorkflowRuleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -941,7 +965,7 @@ proxier! { DescribeWorkflowRuleRequest, DescribeWorkflowRuleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -950,7 +974,7 @@ proxier! { DeleteWorkflowRuleRequest, DeleteWorkflowRuleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -959,7 +983,7 @@ proxier! { ListWorkflowRulesRequest, ListWorkflowRulesResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -968,7 +992,7 @@ proxier! { TriggerWorkflowRuleRequest, TriggerWorkflowRuleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -982,7 +1006,7 @@ proxier! { RespondQueryTaskCompletedRequest, RespondQueryTaskCompletedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -991,7 +1015,7 @@ proxier! { ResetStickyTaskQueueRequest, ResetStickyTaskQueueResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1000,7 +1024,7 @@ proxier! { QueryWorkflowRequest, QueryWorkflowResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1009,7 +1033,7 @@ proxier! { DescribeWorkflowExecutionRequest, DescribeWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1018,7 +1042,7 @@ proxier! { DescribeTaskQueueRequest, DescribeTaskQueueResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -1038,7 +1062,7 @@ proxier! { ListTaskQueuePartitionsRequest, ListTaskQueuePartitionsResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -1048,7 +1072,7 @@ proxier! { CreateScheduleRequest, CreateScheduleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1057,7 +1081,7 @@ proxier! { DescribeScheduleRequest, DescribeScheduleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1066,7 +1090,7 @@ proxier! { UpdateScheduleRequest, UpdateScheduleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1075,7 +1099,7 @@ proxier! { PatchScheduleRequest, PatchScheduleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1084,7 +1108,7 @@ proxier! { ListScheduleMatchingTimesRequest, ListScheduleMatchingTimesResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1093,7 +1117,7 @@ proxier! { DeleteScheduleRequest, DeleteScheduleResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1102,7 +1126,7 @@ proxier! { ListSchedulesRequest, ListSchedulesResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1111,7 +1135,7 @@ proxier! { CountSchedulesRequest, CountSchedulesResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1120,7 +1144,7 @@ proxier! { UpdateWorkerBuildIdCompatibilityRequest, UpdateWorkerBuildIdCompatibilityResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q_str(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -1130,7 +1154,7 @@ proxier! { GetWorkerBuildIdCompatibilityRequest, GetWorkerBuildIdCompatibilityResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q_str(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -1140,7 +1164,7 @@ proxier! { GetWorkerTaskReachabilityRequest, GetWorkerTaskReachabilityResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1149,7 +1173,7 @@ proxier! { UpdateWorkflowExecutionRequest, UpdateWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); let exts = r.extensions_mut(); exts.insert(labels); exts.insert(IsUserLongPoll); @@ -1160,7 +1184,7 @@ proxier! { PollWorkflowExecutionUpdateRequest, PollWorkflowExecutionUpdateResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1169,7 +1193,7 @@ proxier! { StartBatchOperationRequest, StartBatchOperationResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1178,7 +1202,7 @@ proxier! { StopBatchOperationRequest, StopBatchOperationResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1187,7 +1211,7 @@ proxier! { DescribeBatchOperationRequest, DescribeBatchOperationResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1196,7 +1220,7 @@ proxier! { DescribeDeploymentRequest, DescribeDeploymentResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1205,7 +1229,7 @@ proxier! { ListBatchOperationsRequest, ListBatchOperationsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1214,7 +1238,7 @@ proxier! { ListDeploymentsRequest, ListDeploymentsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1223,7 +1247,7 @@ proxier! { ExecuteMultiOperationRequest, ExecuteMultiOperationResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1232,7 +1256,7 @@ proxier! { GetCurrentDeploymentRequest, GetCurrentDeploymentResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1241,7 +1265,7 @@ proxier! { GetDeploymentReachabilityRequest, GetDeploymentReachabilityResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1250,7 +1274,7 @@ proxier! { GetWorkerVersioningRulesRequest, GetWorkerVersioningRulesResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q_str(&r.get_ref().task_queue); r.extensions_mut().insert(labels); } @@ -1260,7 +1284,7 @@ proxier! { UpdateWorkerVersioningRulesRequest, UpdateWorkerVersioningRulesResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q_str(&r.get_ref().task_queue); r.extensions_mut().insert(labels); } @@ -1270,7 +1294,7 @@ proxier! { PollNexusTaskQueueRequest, PollNexusTaskQueueResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -1280,7 +1304,7 @@ proxier! { RespondNexusTaskCompletedRequest, RespondNexusTaskCompletedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1289,7 +1313,7 @@ proxier! { RespondNexusTaskFailedRequest, RespondNexusTaskFailedResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1298,7 +1322,7 @@ proxier! { SetCurrentDeploymentRequest, SetCurrentDeploymentResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1307,7 +1331,7 @@ proxier! { ShutdownWorkerRequest, ShutdownWorkerResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1316,7 +1340,7 @@ proxier! { UpdateActivityOptionsRequest, UpdateActivityOptionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1325,7 +1349,7 @@ proxier! { PauseActivityRequest, PauseActivityResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1334,7 +1358,7 @@ proxier! { UnpauseActivityRequest, UnpauseActivityResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1343,7 +1367,7 @@ proxier! { UpdateWorkflowExecutionOptionsRequest, UpdateWorkflowExecutionOptionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1352,7 +1376,7 @@ proxier! { ResetActivityRequest, ResetActivityResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1361,7 +1385,7 @@ proxier! { DeleteWorkerDeploymentRequest, DeleteWorkerDeploymentResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1370,7 +1394,7 @@ proxier! { DeleteWorkerDeploymentVersionRequest, DeleteWorkerDeploymentVersionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1379,7 +1403,7 @@ proxier! { DescribeWorkerDeploymentRequest, DescribeWorkerDeploymentResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1388,7 +1412,7 @@ proxier! { DescribeWorkerDeploymentVersionRequest, DescribeWorkerDeploymentVersionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1397,7 +1421,7 @@ proxier! { ListWorkerDeploymentsRequest, ListWorkerDeploymentsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1406,7 +1430,7 @@ proxier! { SetWorkerDeploymentCurrentVersionRequest, SetWorkerDeploymentCurrentVersionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1415,7 +1439,7 @@ proxier! { SetWorkerDeploymentRampingVersionRequest, SetWorkerDeploymentRampingVersionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1424,7 +1448,7 @@ proxier! { UpdateWorkerDeploymentVersionMetadataRequest, UpdateWorkerDeploymentVersionMetadataResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1433,7 +1457,7 @@ proxier! { ListWorkersRequest, ListWorkersResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1442,7 +1466,7 @@ proxier! { RecordWorkerHeartbeatRequest, RecordWorkerHeartbeatResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1451,7 +1475,7 @@ proxier! { UpdateTaskQueueConfigRequest, UpdateTaskQueueConfigResponse, |r| { - let mut labels = namespaced_request!(r); + let mut labels = request_with_headers!(r); labels.task_q_str(r.get_ref().task_queue.clone()); r.extensions_mut().insert(labels); } @@ -1461,7 +1485,7 @@ proxier! { FetchWorkerConfigRequest, FetchWorkerConfigResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1470,7 +1494,7 @@ proxier! { UpdateWorkerConfigRequest, UpdateWorkerConfigResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1479,7 +1503,7 @@ proxier! { DescribeWorkerRequest, DescribeWorkerResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1488,7 +1512,7 @@ proxier! { SetWorkerDeploymentManagerRequest, SetWorkerDeploymentManagerResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1497,7 +1521,7 @@ proxier! { PauseWorkflowExecutionRequest, PauseWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1506,7 +1530,7 @@ proxier! { UnpauseWorkflowExecutionRequest, UnpauseWorkflowExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1515,7 +1539,7 @@ proxier! { StartActivityExecutionRequest, StartActivityExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1524,7 +1548,7 @@ proxier! { DescribeActivityExecutionRequest, DescribeActivityExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1533,7 +1557,7 @@ proxier! { PollActivityExecutionRequest, PollActivityExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1542,7 +1566,7 @@ proxier! { ListActivityExecutionsRequest, ListActivityExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1551,7 +1575,7 @@ proxier! { CountActivityExecutionsRequest, CountActivityExecutionsResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1560,7 +1584,7 @@ proxier! { RequestCancelActivityExecutionRequest, RequestCancelActivityExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1569,7 +1593,7 @@ proxier! { TerminateActivityExecutionRequest, TerminateActivityExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1578,7 +1602,7 @@ proxier! { DeleteActivityExecutionRequest, DeleteActivityExecutionResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1591,7 +1615,7 @@ proxier! { (list_search_attributes, ListSearchAttributesRequest, ListSearchAttributesResponse); (delete_namespace, DeleteNamespaceRequest, DeleteNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); @@ -1618,20 +1642,20 @@ proxier! { (get_namespaces, cloudreq::GetNamespacesRequest, cloudreq::GetNamespacesResponse); (get_namespace, cloudreq::GetNamespaceRequest, cloudreq::GetNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); (update_namespace, cloudreq::UpdateNamespaceRequest, cloudreq::UpdateNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); (rename_custom_search_attribute, cloudreq::RenameCustomSearchAttributeRequest, cloudreq::RenameCustomSearchAttributeResponse); (delete_namespace, cloudreq::DeleteNamespaceRequest, cloudreq::DeleteNamespaceResponse, |r| { - let labels = namespaced_request!(r); + let labels = request_with_headers!(r); r.extensions_mut().insert(labels); } ); diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 16682489a..0183b0256 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -100,6 +100,7 @@ prost = { workspace = true } prost-types = "0.14" tonic-prost-build = { workspace = true } pbjson-build = { workspace = true } +prost-reflect = "0.15" [lints] workspace = true diff --git a/crates/common/build.rs b/crates/common/build.rs index 45aed80ae..0ba1cdb9a 100644 --- a/crates/common/build.rs +++ b/crates/common/build.rs @@ -1,4 +1,5 @@ use prost::Message; +use prost_reflect::ReflectMessage; use prost_types::{ DescriptorProto, FieldDescriptorProto, FileDescriptorSet, MessageOptions, field_descriptor_proto::{Label, Type}, @@ -58,8 +59,6 @@ fn main() -> Result<(), Box> { let out = PathBuf::from(env::var("OUT_DIR").unwrap()); let descriptor_file = out.join("descriptors.bin"); let mut builder = tonic_prost_build::configure() - // We don't actually want to build the grpc definitions - we don't need them (for now). - // Just build the message structs. .build_server(false) .build_client(true) // Make conversions easier for some types @@ -177,6 +176,7 @@ fn main() -> Result<(), Box> { )?; generate_payload_visitor(&out, &descriptor_file)?; + generate_request_headers(&out, &descriptor_file)?; // TODO [rust-sdk-branch]: support normal JSON and proto JSON serialization let descriptors = std::fs::read(&descriptor_file)?; pbjson_build::Builder::new() @@ -404,18 +404,7 @@ impl PayloadVisitorGenerator { } fn to_map_entry_name(field_name: &str) -> String { - let mut result = String::new(); - let mut capitalize_next = true; - for c in field_name.chars() { - if c == '_' { - capitalize_next = true; - } else if capitalize_next { - result.push(c.to_ascii_uppercase()); - capitalize_next = false; - } else { - result.push(c); - } - } + let mut result = to_pascal_case(field_name); result.push_str("Entry"); result } @@ -598,7 +587,7 @@ impl PayloadVisitorGenerator { } fn generate_impl(&self, proto_name: &str, fields: &[PayloadFieldInfo]) -> String { - let rust_path = self.proto_to_rust_path(proto_name); + let rust_path = proto_to_rust_path(proto_name); let mut impl_body = String::new(); @@ -633,7 +622,7 @@ impl crate::payload_visitor::PayloadVisitable for {rust_path} {{ proto_path: &str, kind: &PayloadFieldKind, ) -> String { - let rust_field = Self::to_snake_case(field_name); + let rust_field = to_snake_case(field_name); match kind { PayloadFieldKind::SinglePayload => { @@ -734,12 +723,12 @@ impl crate::payload_visitor::PayloadVisitable for {rust_path} {{ // Get the full rust path to the oneof enum let enum_path = self.proto_to_rust_oneof_enum_path(parent_proto_name, oneof_name); // The field in the struct is snake_case of the oneof field name - let rust_field = Self::to_snake_case(oneof_name); + let rust_field = to_snake_case(oneof_name); let mut arms = String::new(); for variant in variants { - let variant_name = Self::to_pascal_case(&variant.name); + let variant_name = to_pascal_case(&variant.name); arms.push_str(&format!( " {enum_path}::{variant}(msg) => msg.visit_payloads_mut(visitor).await,\n", enum_path = enum_path, @@ -772,74 +761,71 @@ impl crate::payload_visitor::PayloadVisitable for {rust_path} {{ } } - fn proto_to_rust_path(&self, proto_name: &str) -> String { - let parts: Vec<&str> = proto_name.split('.').collect(); - let mut rust_parts = Vec::new(); - - // Handle the package -> module mapping - for (i, part) in parts.iter().enumerate() { - if i == parts.len() - 1 { - // Last part is the type name - keep PascalCase - rust_parts.push((*part).to_string()); - } else { - // Package parts become snake_case modules - rust_parts.push(Self::to_snake_case(part)); - } - } - - // The protos module structure - let path = rust_parts.join("::"); - - // Map to the actual crate paths - format!("crate::protos::{}", path) - } - fn proto_to_rust_oneof_enum_path(&self, parent_proto_name: &str, oneof_name: &str) -> String { let parts: Vec<&str> = parent_proto_name.split('.').collect(); let mut rust_parts = Vec::new(); // All parts become snake_case modules (struct name becomes a module containing the enum) for part in parts.iter() { - rust_parts.push(Self::to_snake_case(part)); + rust_parts.push(to_snake_case(part)); } let module_path = rust_parts.join("::"); // The enum name is PascalCase of the oneof field name - let enum_name = Self::to_pascal_case(oneof_name); + let enum_name = to_pascal_case(oneof_name); format!("crate::protos::{}::{}", module_path, enum_name) } +} - fn to_snake_case(s: &str) -> String { - let mut result = String::new(); - for (i, c) in s.chars().enumerate() { - if c.is_uppercase() { - if i > 0 { - result.push('_'); - } - result.push(c.to_ascii_lowercase()); - } else { - result.push(c); +fn to_snake_case(s: &str) -> String { + let mut result = String::new(); + for (i, c) in s.chars().enumerate() { + if c.is_uppercase() { + if i > 0 { + result.push('_'); } + result.push(c.to_ascii_lowercase()); + } else { + result.push(c); } - result } + result +} - fn to_pascal_case(s: &str) -> String { - let mut result = String::new(); - let mut capitalize_next = true; - for c in s.chars() { - if c == '_' { - capitalize_next = true; - } else if capitalize_next { - result.push(c.to_ascii_uppercase()); - capitalize_next = false; - } else { - result.push(c); - } +fn to_pascal_case(s: &str) -> String { + let mut result = String::new(); + let mut capitalize_next = true; + for c in s.chars() { + if c == '_' { + capitalize_next = true; + } else if capitalize_next { + result.push(c.to_ascii_uppercase()); + capitalize_next = false; + } else { + result.push(c); + } + } + result +} + +/// Convert a proto fully-qualified name to a Rust type path under `crate::protos::`. +fn proto_to_rust_path(proto_name: &str) -> String { + let parts: Vec<&str> = proto_name.split('.').collect(); + let mut rust_parts = Vec::new(); + + for (i, part) in parts.iter().enumerate() { + if i == parts.len() - 1 { + // Last part is the type name - keep PascalCase + rust_parts.push((*part).to_string()); + } else { + // Package parts become snake_case modules + rust_parts.push(to_snake_case(part)); } - result } + + let path = rust_parts.join("::"); + format!("crate::protos::{}", path) } fn is_message_type(field: &FieldDescriptorProto) -> bool { @@ -855,3 +841,230 @@ fn is_map_entry(options: &Option) -> bool { .as_ref() .is_some_and(|o| o.map_entry.unwrap_or(false)) } + +/// Generate request header extraction implementations by parsing proto descriptors. +fn generate_request_headers( + out_dir: &Path, + descriptor_path: &Path, +) -> Result<(), Box> { + let mut descriptor_bytes = Vec::new(); + File::open(descriptor_path)?.read_to_end(&mut descriptor_bytes)?; + + // Use prost_reflect to parse descriptors with extension support + let descriptor_set = prost_reflect::DescriptorPool::decode(descriptor_bytes.as_slice()) + .map_err(|e| format!("Failed to decode descriptor pool: {}", e))?; + + let mut generator = RequestHeaderGenerator::new(); + generator.process_descriptors_reflect(&descriptor_set)?; + + let output_path = out_dir.join("request_header_impl.rs"); + let mut file = File::create(&output_path)?; + file.write_all(generator.generate().as_bytes())?; + + Ok(()) +} + +#[derive(Debug, Clone)] +struct MethodHeaderInfo { + request_rust_type: String, + headers: Vec, +} + +#[derive(Debug, Clone)] +struct HeaderInfo { + header_name: String, + value_template: String, + field_paths: Vec, +} + +/// Generator for request header extraction implementations. +struct RequestHeaderGenerator { + /// Methods that have request_header annotations + method_headers: Vec, +} + +impl RequestHeaderGenerator { + fn new() -> Self { + Self { + method_headers: Vec::new(), + } + } + + fn process_descriptors_reflect( + &mut self, + descriptor_pool: &prost_reflect::DescriptorPool, + ) -> Result<(), Box> { + let request_header_ext = descriptor_pool + .get_extension_by_name("temporal.api.protometa.v1.request_header") + .ok_or("Could not find request_header extension")?; + + for service in descriptor_pool.services() { + for method in service.methods() { + let method_options = method.options(); + if !method_options.has_extension(&request_header_ext) { + continue; + } + let extension_value = method_options.get_extension(&request_header_ext); + let request_rust_type = + proto_to_rust_path(method.input().full_name()); + + // Collect annotation messages (may be single or repeated) + let messages: Vec<_> = match &*extension_value { + prost_reflect::Value::List(list) => list + .iter() + .filter_map(|v| match v { + prost_reflect::Value::Message(m) => Some(m), + _ => None, + }) + .collect(), + prost_reflect::Value::Message(m) => vec![m], + _ => continue, + }; + + let headers: Vec<_> = messages + .into_iter() + .filter_map(parse_annotation) + .collect(); + + if !headers.is_empty() { + self.method_headers.push(MethodHeaderInfo { + request_rust_type, + headers, + }); + } + } + } + + Ok(()) + } + + fn generate(&self) -> String { + let mut output = String::from( + r#"// Generated from descriptors.bin - DO NOT EDIT + +/// Extract headers from request messages based on proto annotations +pub fn extract_temporal_request_headers( + request: &dyn Any, + existing_metadata: Option<&MetadataMap>, +) -> Vec<(String, String)> { + let mut headers = Vec::new(); + + // Extract headers from proto annotations +"#, + ); + + for method_info in &self.method_headers { + for header in &method_info.headers { + output.push_str(&generate_header_check( + &method_info.request_rust_type, + header, + )); + } + } + + output.push_str(" headers\n}\n"); + output + } +} + +fn parse_annotation(msg: &prost_reflect::DynamicMessage) -> Option { + let header_name = msg + .get_field(&msg.descriptor().get_field_by_name("header")?) + .as_str() + .unwrap_or("") + .to_string(); + let value_template = msg + .get_field(&msg.descriptor().get_field_by_name("value")?) + .as_str() + .unwrap_or("") + .to_string(); + + if header_name.is_empty() || value_template.is_empty() { + return None; + } + + let field_paths = parse_field_paths(&value_template); + Some(HeaderInfo { + header_name, + value_template, + field_paths, + }) +} + +fn parse_field_paths(template: &str) -> Vec { + let mut paths = Vec::new(); + let mut chars = template.chars(); + while let Some(c) = chars.next() { + if c == '{' { + let path: String = chars.by_ref().take_while(|&c| c != '}').collect(); + if !path.is_empty() { + paths.push(path); + } + } + } + paths +} + +fn generate_header_check(request_rust_type: &str, header: &HeaderInfo) -> String { + if header.field_paths.is_empty() { + // Static header value (no field interpolation) + return format!( + r#" if request.downcast_ref::<{type_}>().is_some() + && existing_metadata.is_none_or(|md| md.get("{hdr}").is_none()) {{ + headers.push(("{hdr}".to_string(), "{val}".to_string())); + }} +"#, + type_ = request_rust_type, + hdr = header.header_name, + val = header.value_template + ); + } + + let mut output = String::new(); + for field_path in &header.field_paths { + let parts: Vec<&str> = field_path.split('.').collect(); + let template_str = header + .value_template + .replace(&format!("{{{}}}", field_path), "{}"); + + // Build the if-let chain: downcast, then optional intermediate fields + let mut conditions = format!( + " if let Some(req) = request.downcast_ref::<{}>()", + request_rust_type + ); + let mut current = "req".to_string(); + for (i, part) in parts[..parts.len() - 1].iter().enumerate() { + let binding = format!("f{}", i); + conditions.push_str(&format!( + "\n && let Some({}) = {}.{}.as_ref()", + binding, + current, + to_snake_case(part) + )); + current = binding; + } + let last_field = to_snake_case(parts[parts.len() - 1]); + let val_ref = format!("{}.{}", current, last_field); + + // Value expression: passthrough or format with template + let value_expr = if template_str == "{}" { + format!("{}.to_string()", val_ref) + } else { + format!("format!(\"{}\", {})", template_str, val_ref) + }; + + output.push_str(&format!( + r#"{conditions} + && !{val_ref}.is_empty() + && existing_metadata.is_none_or(|md| md.get("{hdr}").is_none()) {{ + headers.push(("{hdr}".to_string(), {value_expr})); + }} +"#, + conditions = conditions, + val_ref = val_ref, + hdr = header.header_name, + value_expr = value_expr + )); + } + output +} diff --git a/crates/common/protos/api_upstream/.github/PULL_REQUEST_TEMPLATE.md b/crates/common/protos/api_upstream/.github/PULL_REQUEST_TEMPLATE.md index 5d6bf3cab..c26236805 100644 --- a/crates/common/protos/api_upstream/.github/PULL_REQUEST_TEMPLATE.md +++ b/crates/common/protos/api_upstream/.github/PULL_REQUEST_TEMPLATE.md @@ -1,5 +1,3 @@ -_**READ BEFORE MERGING:** All PRs require approval by both Server AND SDK teams before merging! This is why the number of required approvals is "2" and not "1"--two reviewers from the same team is NOT sufficient. If your PR is not approved by someone in BOTH teams, it may be summarily reverted._ - **What changed?** diff --git a/crates/common/protos/api_upstream/openapi/openapiv2.json b/crates/common/protos/api_upstream/openapi/openapiv2.json index ebcbea209..ddbc48c61 100644 --- a/crates/common/protos/api_upstream/openapi/openapiv2.json +++ b/crates/common/protos/api_upstream/openapi/openapiv2.json @@ -13044,6 +13044,10 @@ }, "description": "Links associated with the event." }, + "principal": { + "$ref": "#/definitions/v1Principal", + "description": "Server-computed authenticated caller identity associated with this event." + }, "workflowExecutionStartedEventAttributes": { "$ref": "#/definitions/v1WorkflowExecutionStartedEventAttributes" }, @@ -14634,6 +14638,20 @@ }, "description": "PostResetOperation represents an operation to be performed on the new workflow execution after a workflow reset." }, + "v1Principal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "Low-cardinality category of the principal (e.g., \"jwt\", \"users\")." + }, + "name": { + "type": "string", + "description": "Identifier within that category (e.g., sub JWT claim, email address)." + } + }, + "description": "Principal is an authenticated caller identity computed by the server from trusted\nauthentication context." + }, "v1Priority": { "type": "object", "properties": { diff --git a/crates/common/protos/api_upstream/openapi/openapiv3.yaml b/crates/common/protos/api_upstream/openapi/openapiv3.yaml index 091a9998e..cd503dccd 100644 --- a/crates/common/protos/api_upstream/openapi/openapiv3.yaml +++ b/crates/common/protos/api_upstream/openapi/openapiv3.yaml @@ -10193,6 +10193,10 @@ components: items: $ref: '#/components/schemas/Link' description: Links associated with the event. + principal: + allOf: + - $ref: '#/components/schemas/Principal' + description: Server-computed authenticated caller identity associated with this event. workflowExecutionStartedEventAttributes: $ref: '#/components/schemas/WorkflowExecutionStartedEventAttributes' workflowExecutionCompletedEventAttributes: @@ -11727,6 +11731,18 @@ components: description: |- UpdateWorkflowOptions represents updating workflow execution options after a workflow reset. Keep the parameters in sync with temporal.api.workflowservice.v1.UpdateWorkflowExecutionOptionsRequest. + Principal: + type: object + properties: + type: + type: string + description: Low-cardinality category of the principal (e.g., "jwt", "users"). + name: + type: string + description: Identifier within that category (e.g., sub JWT claim, email address). + description: |- + Principal is an authenticated caller identity computed by the server from trusted + authentication context. Priority: type: object properties: diff --git a/crates/common/protos/api_upstream/temporal/api/common/v1/message.proto b/crates/common/protos/api_upstream/temporal/api/common/v1/message.proto index aa5b3f370..88c3a834e 100644 --- a/crates/common/protos/api_upstream/temporal/api/common/v1/message.proto +++ b/crates/common/protos/api_upstream/temporal/api/common/v1/message.proto @@ -246,6 +246,15 @@ message Link { } } +// Principal is an authenticated caller identity computed by the server from trusted +// authentication context. +message Principal { + // Low-cardinality category of the principal (e.g., "jwt", "users"). + string type = 1; + // Identifier within that category (e.g., sub JWT claim, email address). + string name = 2; +} + // Priority contains metadata that controls relative ordering of task processing // when tasks are backed up in a queue. Initially, Priority will be used in // matching (workflow and activity) task queues. Later it may be used in history diff --git a/crates/common/protos/api_upstream/temporal/api/history/v1/message.proto b/crates/common/protos/api_upstream/temporal/api/history/v1/message.proto index 21fb13c5e..34a4286eb 100644 --- a/crates/common/protos/api_upstream/temporal/api/history/v1/message.proto +++ b/crates/common/protos/api_upstream/temporal/api/history/v1/message.proto @@ -1136,6 +1136,8 @@ message HistoryEvent { temporal.api.sdk.v1.UserMetadata user_metadata = 301; // Links associated with the event. repeated temporal.api.common.v1.Link links = 302; + // Server-computed authenticated caller identity associated with this event. + temporal.api.common.v1.Principal principal = 303; // The event details. The type must match that in `event_type`. oneof attributes { WorkflowExecutionStartedEventAttributes workflow_execution_started_event_attributes = 6; diff --git a/crates/common/protos/api_upstream/temporal/api/workflowservice/v1/service.proto b/crates/common/protos/api_upstream/temporal/api/workflowservice/v1/service.proto index cbd71e2ba..18e579c2c 100644 --- a/crates/common/protos/api_upstream/temporal/api/workflowservice/v1/service.proto +++ b/crates/common/protos/api_upstream/temporal/api/workflowservice/v1/service.proto @@ -121,7 +121,7 @@ service WorkflowService { rpc ExecuteMultiOperation (ExecuteMultiOperationRequest) returns (ExecuteMultiOperationResponse) { option (temporal.api.protometa.v1.request_header) = { header: "temporal-resource-id" - value: "workflow:{resource_id}" + value: "{resource_id}" }; } @@ -180,7 +180,7 @@ service WorkflowService { rpc RespondWorkflowTaskCompleted (RespondWorkflowTaskCompletedRequest) returns (RespondWorkflowTaskCompletedResponse) { option (temporal.api.protometa.v1.request_header) = { header: "temporal-resource-id" - value: "workflow:{resource_id}" + value: "{resource_id}" }; } @@ -199,7 +199,7 @@ service WorkflowService { rpc RespondWorkflowTaskFailed (RespondWorkflowTaskFailedRequest) returns (RespondWorkflowTaskFailedResponse) { option (temporal.api.protometa.v1.request_header) = { header: "temporal-resource-id" - value: "workflow:{resource_id}" + value: "{resource_id}" }; } @@ -1460,7 +1460,7 @@ service WorkflowService { }; option (temporal.api.protometa.v1.request_header) = { header: "temporal-resource-id" - value: "worker:{resource_id}" + value: "{resource_id}" }; }; @@ -1505,7 +1505,7 @@ service WorkflowService { }; option (temporal.api.protometa.v1.request_header) = { header: "temporal-resource-id" - value: "worker:{resource_id}" + value: "{resource_id}" }; } @@ -1523,7 +1523,7 @@ service WorkflowService { }; option (temporal.api.protometa.v1.request_header) = { header: "temporal-resource-id" - value: "worker:{resource_id}" + value: "{resource_id}" }; } diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 3b2e92053..a1b220caf 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -16,6 +16,7 @@ pub mod fsm_trait; pub mod payload_visitor; mod priority; pub mod protos; +pub mod request_headers; pub mod telemetry; pub mod worker; mod workflow_definition; diff --git a/crates/common/src/request_headers.rs b/crates/common/src/request_headers.rs new file mode 100644 index 000000000..5a146ea4e --- /dev/null +++ b/crates/common/src/request_headers.rs @@ -0,0 +1,244 @@ +//! Request header extraction functionality. +//! +//! This module provides functionality to automatically extract field values from request messages +//! and convert them to HTTP headers based on protobuf annotations. + +use std::any::Any; +use tonic::metadata::MetadataMap; + +// Include the generated implementation +include!(concat!(env!("OUT_DIR"), "/request_header_impl.rs")); + +#[cfg(test)] +mod tests { + use super::*; + use crate::protos::temporal::api::workflowservice::v1::*; + + /// Helper to extract the temporal-resource-id header value from a request. + fn extract_resource_id(request: &dyn Any) -> Option { + extract_temporal_request_headers(request, None) + .into_iter() + .find(|(k, _)| k == "temporal-resource-id") + .map(|(_, v)| v) + } + + #[test] + fn existing_metadata_prevents_duplicate_header() { + let request = StartWorkflowExecutionRequest { + workflow_id: "wf-1".to_string(), + ..Default::default() + }; + let mut metadata = MetadataMap::new(); + metadata.insert("temporal-resource-id", "existing".parse().unwrap()); + + let headers = extract_temporal_request_headers(&request as &dyn Any, Some(&metadata)); + assert!(headers.is_empty()); + } + + #[test] + fn empty_field_produces_no_header() { + let request = StartWorkflowExecutionRequest::default(); + assert_eq!(extract_resource_id(&request), None); + } + + // --- Workflow request headers (annotation extracts from workflow_id or resource_id) --- + + #[test] + fn start_workflow_execution() { + let request = StartWorkflowExecutionRequest { + workflow_id: "my-wf".to_string(), + ..Default::default() + }; + assert_eq!(extract_resource_id(&request), Some("workflow:my-wf".into())); + } + + #[test] + fn signal_workflow_execution() { + let request = SignalWorkflowExecutionRequest { + workflow_execution: Some( + crate::protos::temporal::api::common::v1::WorkflowExecution { + workflow_id: "sig-wf".to_string(), + ..Default::default() + }, + ), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:sig-wf".into()) + ); + } + + // --- Workflow task requests (resource_id field, annotation: "{resource_id}") --- + // The SDK sets resource_id with prefix already applied (e.g., "workflow:wf-id"), + // and the annotation passes it through as-is. + + #[test] + fn respond_workflow_task_completed() { + let request = RespondWorkflowTaskCompletedRequest { + resource_id: "workflow:wf-completed".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:wf-completed".into()) + ); + } + + #[test] + fn respond_workflow_task_failed() { + let request = RespondWorkflowTaskFailedRequest { + resource_id: "workflow:wf-456".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:wf-456".into()) + ); + } + + // --- Activity task requests (resource_id field, annotation: "{resource_id}") --- + + #[test] + fn respond_activity_task_completed() { + let request = RespondActivityTaskCompletedRequest { + resource_id: "workflow:act-wf".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:act-wf".into()) + ); + } + + #[test] + fn respond_activity_task_completed_standalone() { + let request = RespondActivityTaskCompletedRequest { + resource_id: "activity:standalone-act".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("activity:standalone-act".into()) + ); + } + + #[test] + fn respond_activity_task_failed() { + let request = RespondActivityTaskFailedRequest { + resource_id: "workflow:fail-wf".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:fail-wf".into()) + ); + } + + #[test] + fn respond_activity_task_canceled() { + let request = RespondActivityTaskCanceledRequest { + resource_id: "workflow:cancel-wf".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:cancel-wf".into()) + ); + } + + #[test] + fn respond_activity_task_completed_by_id() { + let request = RespondActivityTaskCompletedByIdRequest { + resource_id: "workflow:byid-wf".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:byid-wf".into()) + ); + } + + #[test] + fn respond_activity_task_failed_by_id() { + let request = RespondActivityTaskFailedByIdRequest { + resource_id: "activity:byid-act".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("activity:byid-act".into()) + ); + } + + #[test] + fn respond_activity_task_canceled_by_id() { + let request = RespondActivityTaskCanceledByIdRequest { + resource_id: "workflow:cancel-byid-wf".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:cancel-byid-wf".into()) + ); + } + + // --- Heartbeat requests --- + + #[test] + fn record_activity_task_heartbeat() { + let request = RecordActivityTaskHeartbeatRequest { + resource_id: "workflow:hb-wf".to_string(), + ..Default::default() + }; + assert_eq!(extract_resource_id(&request), Some("workflow:hb-wf".into())); + } + + #[test] + fn record_activity_task_heartbeat_by_id() { + let request = RecordActivityTaskHeartbeatByIdRequest { + resource_id: "activity:hb-act".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("activity:hb-act".into()) + ); + } + + // --- Worker heartbeat (annotation: "{resource_id}", SDK prefixes with "worker:") --- + + #[test] + fn record_worker_heartbeat() { + let request = RecordWorkerHeartbeatRequest { + resource_id: "worker:some-grouping-key".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("worker:some-grouping-key".into()) + ); + } + + // --- Batch operation (annotation: "{resource_id}", SDK prefixes with "workflow:") --- + + #[test] + fn execute_multi_operation() { + let request = ExecuteMultiOperationRequest { + resource_id: "workflow:multi-op-wf".to_string(), + ..Default::default() + }; + assert_eq!( + extract_resource_id(&request), + Some("workflow:multi-op-wf".into()) + ); + } + + // --- Empty resource_id produces no header --- + + #[test] + fn empty_resource_id_produces_no_header() { + let request = RespondActivityTaskCompletedRequest::default(); + assert_eq!(extract_resource_id(&request), None); + } +} diff --git a/crates/sdk-core/src/core_tests/activity_tasks.rs b/crates/sdk-core/src/core_tests/activity_tasks.rs index 433759c90..fd0085c28 100644 --- a/crates/sdk-core/src/core_tests/activity_tasks.rs +++ b/crates/sdk-core/src/core_tests/activity_tasks.rs @@ -90,7 +90,7 @@ async fn max_activities_respected() { .returning(move |_, _| Ok(tasks.pop_front().unwrap())); mock_client .expect_complete_activity_task() - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); let worker = Worker::new_test( test_worker_cfg() @@ -140,7 +140,7 @@ async fn heartbeats_report_cancels_only_once() { mock_client .expect_record_activity_heartbeat() .times(2) - .returning(|_, _| { + .returning(|_, _, _| { Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: true, activity_paused: false, @@ -150,11 +150,11 @@ async fn heartbeats_report_cancels_only_once() { mock_client .expect_complete_activity_task() .times(1) - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); mock_client .expect_cancel_activity_task() .times(1) - .returning(|_, _| Ok(RespondActivityTaskCanceledResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCanceledResponse::default())); let core = mock_worker(MocksHolder::from_client_with_activities( mock_client, @@ -266,7 +266,7 @@ async fn activity_cancel_interrupts_poll() { mock_client .expect_record_activity_heartbeat() .times(1) - .returning(|_, _| { + .returning(|_, _, _| { async { Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: true, @@ -279,7 +279,7 @@ async fn activity_cancel_interrupts_poll() { mock_client .expect_complete_activity_task() .times(1) - .returning(|_, _| async { Ok(RespondActivityTaskCompletedResponse::default()) }.boxed()); + .returning(|_, _, _| async { Ok(RespondActivityTaskCompletedResponse::default()) }.boxed()); let mw = MockWorkerInputs { act_poller: Some(Box::from(mock_poller)), @@ -377,10 +377,10 @@ async fn many_concurrent_heartbeat_cancels() { .returning(move |_, _| poll_resps.pop_front().unwrap()); mock_client .expect_cancel_activity_task() - .returning(move |_, _| async move { Ok(Default::default()) }.boxed()); + .returning(move |_, _, _| async move { Ok(Default::default()) }.boxed()); mock_client .expect_record_activity_heartbeat() - .returning(move |tt, _| { + .returning(move |tt, _, _| { let calls = match calls_map.entry(tt) { Entry::Occupied(mut e) => { *e.get_mut() += 1; @@ -609,7 +609,7 @@ async fn can_heartbeat_acts_during_shutdown() { mock_client .expect_record_activity_heartbeat() .times(1) - .returning(|_, _| { + .returning(|_, _, _| { Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: false, activity_paused: false, @@ -619,7 +619,7 @@ async fn can_heartbeat_acts_during_shutdown() { mock_client .expect_complete_activity_task() .times(1) - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); let core = mock_worker(MocksHolder::from_client_with_activities( mock_client, @@ -663,7 +663,7 @@ async fn complete_act_with_fail_flushes_heartbeat() { .expect_record_activity_heartbeat() // Two times b/c we always record the first heartbeat, and we'll flush the last .times(2) - .returning_st(move |_, payload| { + .returning_st(move |_, payload, _| { *lsp.borrow_mut() = payload; Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: false, @@ -674,7 +674,7 @@ async fn complete_act_with_fail_flushes_heartbeat() { mock_client .expect_fail_activity_task() .times(1) - .returning(|_, _| Ok(RespondActivityTaskFailedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskFailedResponse::default())); let core = mock_worker(MocksHolder::from_client_with_activities( mock_client, @@ -745,7 +745,7 @@ async fn max_worker_acts_per_second_respected() { }); mock_client .expect_complete_activity_task() - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); let cfg = test_worker_cfg() .activity_task_poller_behavior(PollerBehavior::SimpleMaximum(1_usize)) @@ -910,7 +910,7 @@ async fn activity_tasks_from_completion_are_delivered() { }); mock.expect_complete_activity_task() .times(3) - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); let act_tasks: Vec> = vec![]; let mut mh = MockPollCfg::from_resp_batches(wfid, t, [1], mock); mh.enforce_correct_number_of_polls = true; @@ -987,7 +987,7 @@ async fn retryable_net_error_exhaustion_is_nonfatal() { mock_client .expect_complete_activity_task() .times(1) - .returning(|_, _| Err(tonic::Status::internal("retryable error"))); + .returning(|_, _, _| Err(tonic::Status::internal("retryable error"))); let core = mock_worker(MocksHolder::from_client_with_activities( mock_client, @@ -1059,7 +1059,7 @@ async fn graceful_shutdown(#[values(true, false)] at_max_outstanding: bool) { mock_client .expect_fail_activity_task() .times(3) - .returning(|_, _| Ok(Default::default())); + .returning(|_, _, _| Ok(Default::default())); let max_outstanding = if at_max_outstanding { 3_usize } else { 100 }; let mw = MockWorkerInputs { @@ -1136,7 +1136,7 @@ async fn activities_must_be_flushed_to_server_on_shutdown(#[values(true, false)] mock_client .expect_complete_activity_task() .times(1) - .returning(|_, _| { + .returning(|_, _, _| { async { // We need some artificial delay here and there's nothing meaningful to sync with tokio::time::sleep(Duration::from_millis(100)).await; @@ -1184,7 +1184,7 @@ async fn heartbeat_response_can_be_paused() { mock_client .expect_record_activity_heartbeat() .times(1) - .returning(|_, _| { + .returning(|_, _, _| { Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: false, activity_paused: true, @@ -1195,7 +1195,7 @@ async fn heartbeat_response_can_be_paused() { mock_client .expect_record_activity_heartbeat() .times(1) - .returning(|_, _| { + .returning(|_, _, _| { Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: true, activity_paused: false, @@ -1206,7 +1206,7 @@ async fn heartbeat_response_can_be_paused() { mock_client .expect_record_activity_heartbeat() .times(1) - .returning(|_, _| { + .returning(|_, _, _| { Ok(RecordActivityTaskHeartbeatResponse { cancel_requested: true, activity_paused: true, @@ -1216,7 +1216,7 @@ async fn heartbeat_response_can_be_paused() { mock_client .expect_cancel_activity_task() .times(3) - .returning(|_, _| Ok(RespondActivityTaskCanceledResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCanceledResponse::default())); let core = mock_worker(MocksHolder::from_client_with_activities( mock_client, @@ -1340,3 +1340,181 @@ async fn heartbeat_response_can_be_paused() { core.drain_activity_poller_and_shutdown().await; } + +#[tokio::test] +async fn activity_completion_sets_workflow_resource_id() { + let mut mock_client = mock_worker_client(); + mock_client + .expect_complete_activity_task() + .withf(|_, _, resource_id| resource_id == "workflow:test-wf-id") + .times(1) + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); + + let core = mock_worker(MocksHolder::from_client_with_activities( + mock_client, + [PollActivityTaskQueueResponse { + task_token: vec![1], + activity_id: "act1".to_string(), + workflow_execution: Some( + temporalio_common::protos::temporal::api::common::v1::WorkflowExecution { + workflow_id: "test-wf-id".to_string(), + run_id: "run-1".to_string(), + }, + ), + ..Default::default() + } + .into()], + )); + + let act = core.poll_activity_task().await.unwrap(); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::ok(vec![1].into())), + }) + .await + .unwrap(); + core.drain_activity_poller_and_shutdown().await; +} + +#[tokio::test] +async fn activity_completion_sets_activity_resource_id_for_standalone() { + let mut mock_client = mock_worker_client(); + mock_client + .expect_complete_activity_task() + .withf(|_, _, resource_id| resource_id == "activity:standalone-act") + .times(1) + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); + + let core = mock_worker(MocksHolder::from_client_with_activities( + mock_client, + [PollActivityTaskQueueResponse { + task_token: vec![1], + activity_id: "standalone-act".to_string(), + ..Default::default() + } + .into()], + )); + + let act = core.poll_activity_task().await.unwrap(); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::ok(vec![1].into())), + }) + .await + .unwrap(); + core.drain_activity_poller_and_shutdown().await; +} + +#[tokio::test] +async fn activity_failure_sets_resource_id() { + let mut mock_client = mock_worker_client(); + mock_client + .expect_fail_activity_task() + .withf(|_, _, resource_id| resource_id == "workflow:fail-wf") + .times(1) + .returning(|_, _, _| Ok(RespondActivityTaskFailedResponse::default())); + + let core = mock_worker(MocksHolder::from_client_with_activities( + mock_client, + [PollActivityTaskQueueResponse { + task_token: vec![1], + activity_id: "act1".to_string(), + workflow_execution: Some( + temporalio_common::protos::temporal::api::common::v1::WorkflowExecution { + workflow_id: "fail-wf".to_string(), + ..Default::default() + }, + ), + ..Default::default() + } + .into()], + )); + + let act = core.poll_activity_task().await.unwrap(); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::fail("boom".into())), + }) + .await + .unwrap(); + core.drain_activity_poller_and_shutdown().await; +} + +#[tokio::test] +async fn activity_heartbeat_sets_resource_id() { + let mut mock_client = mock_worker_client(); + mock_client + .expect_record_activity_heartbeat() + .withf(|_, _, resource_id| resource_id == "workflow:hb-wf") + .times(1) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())); + mock_client + .expect_complete_activity_task() + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); + + let core = mock_worker(MocksHolder::from_client_with_activities( + mock_client, + [PollActivityTaskQueueResponse { + task_token: vec![1], + activity_id: "act1".to_string(), + heartbeat_timeout: Some(prost_dur!(from_millis(1))), + workflow_execution: Some( + temporalio_common::protos::temporal::api::common::v1::WorkflowExecution { + workflow_id: "hb-wf".to_string(), + ..Default::default() + }, + ), + ..Default::default() + } + .into()], + )); + + let act = core.poll_activity_task().await.unwrap(); + core.record_activity_heartbeat(ActivityHeartbeat { + task_token: act.task_token.clone(), + details: vec![vec![1_u8].into()], + }); + sleep(Duration::from_millis(10)).await; + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::ok(vec![1].into())), + }) + .await + .unwrap(); + core.drain_activity_poller_and_shutdown().await; +} + +#[tokio::test] +async fn activity_cancellation_sets_resource_id() { + let mut mock_client = mock_worker_client(); + mock_client + .expect_cancel_activity_task() + .withf(|_, _, resource_id| resource_id == "workflow:cancel-wf") + .times(1) + .returning(|_, _, _| Ok(RespondActivityTaskCanceledResponse::default())); + + let core = mock_worker(MocksHolder::from_client_with_activities( + mock_client, + [PollActivityTaskQueueResponse { + task_token: vec![1], + activity_id: "act1".to_string(), + workflow_execution: Some( + temporalio_common::protos::temporal::api::common::v1::WorkflowExecution { + workflow_id: "cancel-wf".to_string(), + ..Default::default() + }, + ), + ..Default::default() + } + .into()], + )); + + let act = core.poll_activity_task().await.unwrap(); + core.complete_activity_task(ActivityTaskCompletion { + task_token: act.task_token, + result: Some(ActivityExecutionResult::cancel_from_details(None)), + }) + .await + .unwrap(); + core.drain_activity_poller_and_shutdown().await; +} diff --git a/crates/sdk-core/src/core_tests/workers.rs b/crates/sdk-core/src/core_tests/workers.rs index 8523dcafd..91085beeb 100644 --- a/crates/sdk-core/src/core_tests/workers.rs +++ b/crates/sdk-core/src/core_tests/workers.rs @@ -496,7 +496,7 @@ async fn test_task_type_combinations_unified( if enable_local_activities || enable_remote_activities { client .expect_complete_activity_task() - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); } if enable_nexus { client diff --git a/crates/sdk-core/src/core_tests/workflow_tasks.rs b/crates/sdk-core/src/core_tests/workflow_tasks.rs index d8c161586..e07f179fa 100644 --- a/crates/sdk-core/src/core_tests/workflow_tasks.rs +++ b/crates/sdk-core/src/core_tests/workflow_tasks.rs @@ -1235,7 +1235,7 @@ async fn fail_wft_then_recover() { ); mh.num_expected_fails = 1; mh.expect_fail_wft_matcher = - Box::new(|_, cause, _| matches!(cause, WorkflowTaskFailedCause::NonDeterministicError)); + Box::new(|_, cause, _, _| matches!(cause, WorkflowTaskFailedCause::NonDeterministicError)); let mut mock = build_mock_pollers(mh); mock.worker_cfg(|wc| { wc.max_cached_workflows = 2; @@ -1293,7 +1293,7 @@ async fn default_wft_fail_cause_is_worker_unhandled() { mock_worker_client(), ); mh.num_expected_fails = 1; - mh.expect_fail_wft_matcher = Box::new(|_, cause, _| { + mh.expect_fail_wft_matcher = Box::new(|_, cause, _, _| { matches!( cause, WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure @@ -1334,7 +1334,7 @@ async fn poll_response_triggers_wf_error() { ); // Fail wft will be called when auto-failing. mh.num_expected_fails = 1; - mh.expect_fail_wft_matcher = Box::new(move |_, cause, _| { + mh.expect_fail_wft_matcher = Box::new(move |_, cause, _, _| { matches!(cause, WorkflowTaskFailedCause::NonDeterministicError) }); let mock = build_mock_pollers(mh); @@ -2317,7 +2317,7 @@ async fn ensure_fetching_fail_during_complete_sends_task_failure() { }) .times(1); mock.expect_fail_workflow_task() - .returning(|_, _, _| Ok(Default::default())) + .returning(|_, _, _, _| Ok(Default::default())) .times(1); let mut mock = single_hist_mock_sg(wfid, t, [ResponseType::Raw(first_poll)], mock, true); diff --git a/crates/sdk-core/src/replay/mod.rs b/crates/sdk-core/src/replay/mod.rs index 7a8ff764e..60e549694 100644 --- a/crates/sdk-core/src/replay/mod.rs +++ b/crates/sdk-core/src/replay/mod.rs @@ -113,7 +113,7 @@ where }); client .expect_fail_workflow_task() - .returning(move |_, _, _| { + .returning(move |_, _, _, _| { hist_allow_tx.send("Failed".to_string()).unwrap(); async move { Ok(RespondWorkflowTaskFailedResponse::default()) }.boxed() }); diff --git a/crates/sdk-core/src/test_help/integ_helpers.rs b/crates/sdk-core/src/test_help/integ_helpers.rs index 829615a45..8131b6cda 100644 --- a/crates/sdk-core/src/test_help/integ_helpers.rs +++ b/crates/sdk-core/src/test_help/integ_helpers.rs @@ -514,7 +514,7 @@ pub struct MockPollCfg { pub mock_client: MockWorkerClient, /// All calls to fail WFTs must match this predicate pub expect_fail_wft_matcher: - Box) -> bool + Send>, + Box, &String) -> bool + Send>, /// All calls to legacy query responses must match this predicate pub expect_legacy_query_matcher: Box bool + Send>, pub completion_mock_fn: Option>, @@ -541,7 +541,7 @@ impl MockPollCfg { num_expected_fails, num_expected_legacy_query_resps: 0, mock_client: mock_worker_client(), - expect_fail_wft_matcher: Box::new(|_, _, _| true), + expect_fail_wft_matcher: Box::new(|_, _, _, _| true), expect_legacy_query_matcher: Box::new(|_, _| true), completion_mock_fn: None, num_expected_completions: None, @@ -581,7 +581,7 @@ impl MockPollCfg { num_expected_fails: 0, num_expected_legacy_query_resps: 0, mock_client, - expect_fail_wft_matcher: Box::new(|_, _, _| true), + expect_fail_wft_matcher: Box::new(|_, _, _, _| true), expect_legacy_query_matcher: Box::new(|_, _| true), completion_mock_fn: None, num_expected_completions: None, @@ -819,7 +819,7 @@ pub fn build_mock_pollers(mut cfg: MockPollCfg) -> MocksHolder { .expect_fail_workflow_task() .withf(cfg.expect_fail_wft_matcher) .times::(cfg.num_expected_fails.into()) - .returning(move |tt, _, _| { + .returning(move |tt, _, _, _| { outstanding.release_token(&tt); Ok(Default::default()) }); diff --git a/crates/sdk-core/src/worker/activities.rs b/crates/sdk-core/src/worker/activities.rs index 71a583ece..017ba4045 100644 --- a/crates/sdk-core/src/worker/activities.rs +++ b/crates/sdk-core/src/worker/activities.rs @@ -85,9 +85,9 @@ impl PendingActivityCancel { /// Contains details that core wants to store while an activity is running. #[derive(Debug)] struct InFlightActInfo { + activity_id: String, activity_type: String, workflow_type: String, - /// Only kept for logging reasons workflow_id: String, /// Only kept for logging reasons workflow_run_id: String, @@ -122,6 +122,7 @@ impl RemoteInFlightActInfo { let wec = poll_resp.workflow_execution.clone().unwrap_or_default(); Self { base: InFlightActInfo { + activity_id: poll_resp.activity_id.clone(), activity_type: poll_resp.activity_type.clone().unwrap_or_default().name, workflow_type: poll_resp.workflow_type.clone().unwrap_or_default().name, workflow_id: wec.workflow_id, @@ -137,6 +138,16 @@ impl RemoteInFlightActInfo { _permit: permit, } } + + fn resource_id(&self) -> String { + if !self.base.workflow_id.is_empty() { + format!("workflow:{}", self.base.workflow_id) + } else if !self.base.activity_id.is_empty() { + format!("activity:{}", self.base.activity_id) + } else { + String::new() + } + } } pub(crate) struct WorkerActivityTasks { @@ -321,11 +332,12 @@ impl WorkerActivityTasks { client: &dyn WorkerClient, ) { if let Some((_, act_info)) = self.outstanding_activity_tasks.remove(&task_token) { + let resource_id = act_info.resource_id(); let act_metrics = self.metrics.with_new_attrs([ activity_type(act_info.base.activity_type), workflow_type(act_info.base.workflow_type), ]); - Span::current().record("workflow_id", act_info.base.workflow_id); + Span::current().record("workflow_id", act_info.base.workflow_id.as_str()); Span::current().record("run_id", act_info.base.workflow_run_id); act_metrics.act_execution_latency(act_info.base.start_time.elapsed()); let known_not_found = act_info.known_not_found; @@ -356,7 +368,11 @@ impl WorkerActivityTasks { act_metrics.act_execution_succeeded(sched_time); } client - .complete_activity_task(task_token.clone(), result.map(Into::into)) + .complete_activity_task( + task_token.clone(), + result.map(Into::into), + resource_id, + ) .await .err() } @@ -365,7 +381,7 @@ impl WorkerActivityTasks { act_metrics.act_execution_failed(); } client - .fail_activity_task(task_token.clone(), failure) + .fail_activity_task(task_token.clone(), failure, resource_id) .await .err() } @@ -381,6 +397,7 @@ impl WorkerActivityTasks { .fail_activity_task( task_token.clone(), Some(worker_shutdown_failure()), + resource_id, ) .await .err() @@ -400,7 +417,7 @@ impl WorkerActivityTasks { None }; client - .cancel_activity_task(task_token.clone(), details) + .cancel_activity_task(task_token.clone(), details, resource_id) .await .err() } @@ -456,8 +473,12 @@ impl WorkerActivityTasks { }; let throttle_interval = std::cmp::min(throttle_interval, self.max_heartbeat_throttle_interval); - self.heartbeat_manager - .record(details, throttle_interval, at_info.timeout_resetter.clone()) + self.heartbeat_manager.record( + details, + throttle_interval, + at_info.timeout_resetter.clone(), + at_info.resource_id(), + ) } /// Returns a handle that the workflows management side can use to interact with this manager @@ -764,7 +785,7 @@ mod tests { mock_client .expect_complete_activity_task() .times(2) - .returning(|_, _| Ok(Default::default())); + .returning(|_, _, _| Ok(Default::default())); let mock_client = Arc::new(mock_client); let sem = fixed_size_permit_dealer(10); let shutdown_token = CancellationToken::new(); @@ -935,7 +956,7 @@ mod tests { mock_client .expect_record_activity_heartbeat() .times(2) - .returning(|_, _| Ok(Default::default())); + .returning(|_, _, _| Ok(Default::default())); let mock_client = Arc::new(mock_client); let sem = fixed_size_permit_dealer(1); let shutdown_token = CancellationToken::new(); diff --git a/crates/sdk-core/src/worker/activities/activity_heartbeat_manager.rs b/crates/sdk-core/src/worker/activities/activity_heartbeat_manager.rs index 9950c6145..595e17562 100644 --- a/crates/sdk-core/src/worker/activities/activity_heartbeat_manager.rs +++ b/crates/sdk-core/src/worker/activities/activity_heartbeat_manager.rs @@ -54,6 +54,7 @@ struct ValidActivityHeartbeat { details: Vec, throttle_interval: Duration, timeout_resetter: Option>, + resource_id: String, } #[derive(Debug)] @@ -64,6 +65,7 @@ enum HeartbeatExecutorAction { Report { task_token: TaskToken, details: Vec, + resource_id: String, }, } @@ -141,9 +143,9 @@ impl ActivityHeartbeatManager { }, }; } - HeartbeatExecutorAction::Report { task_token: tt, details } => { + HeartbeatExecutorAction::Report { task_token: tt, details, resource_id} => { match sg - .record_activity_heartbeat(tt.clone(), details.into_payloads()) + .record_activity_heartbeat(tt.clone(), details.into_payloads(), resource_id) .await { Ok(RecordActivityTaskHeartbeatResponse { @@ -217,6 +219,7 @@ impl ActivityHeartbeatManager { hb: ActivityHeartbeat, throttle_interval: Duration, timeout_resetter: Option>, + resource_id: String, ) -> Result<(), ActivityHeartbeatError> { self.heartbeat_tx .send(HeartbeatAction::SendHeartbeat(ValidActivityHeartbeat { @@ -224,6 +227,7 @@ impl ActivityHeartbeatManager { details: hb.details, throttle_interval, timeout_resetter, + resource_id, })) .expect("Receive half of the heartbeats event channel must not be dropped"); @@ -271,6 +275,7 @@ struct ActivityHeartbeatState { throttle_interval: Duration, throttled_cancellation_token: Option, timeout_resetter: Option>, + resource_id: String, } impl ActivityHeartbeatState { @@ -327,11 +332,13 @@ impl HeartbeatStreamState { is_record_in_flight: true, throttled_cancellation_token: None, timeout_resetter: hb.timeout_resetter, + resource_id: hb.resource_id.clone(), }; e.insert(state); Some(HeartbeatExecutorAction::Report { task_token: hb.task_token, details: hb.details, + resource_id: hb.resource_id, }) } Entry::Occupied(mut o) => { @@ -380,6 +387,7 @@ impl HeartbeatStreamState { Some(HeartbeatExecutorAction::Report { task_token: tt, details, + resource_id: state.resource_id.clone(), }) } else { // Nothing to report, forget this task token @@ -412,6 +420,7 @@ impl HeartbeatStreamState { return Some(HeartbeatExecutorAction::Report { task_token: tt, details: last_deets, + resource_id: state.resource_id, }); } else if state.is_record_in_flight { self.tt_needs_flush.insert(tt, on_complete); @@ -442,7 +451,7 @@ mod test { let mut mock_client = mock_worker_client(); mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(2); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -464,7 +473,7 @@ mod test { let mut mock_client = mock_worker_client(); mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(3); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -483,7 +492,7 @@ mod test { let mut mock_client = mock_worker_client(); mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(1); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -503,7 +512,7 @@ mod test { let mut mock_client = mock_worker_client(); mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(2); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -521,7 +530,7 @@ mod test { let mut mock_client = mock_worker_client(); mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(2); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -542,7 +551,7 @@ mod test { let mut mock_client = mock_worker_client(); mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(1); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -558,7 +567,7 @@ mod test { // Should only expect 1 heartbeat call, not 2 (the second would be from evict flushing) mock_client .expect_record_activity_heartbeat() - .returning(|_, _| Ok(RecordActivityTaskHeartbeatResponse::default())) + .returning(|_, _, _| Ok(RecordActivityTaskHeartbeatResponse::default())) .times(1); let (cancel_tx, _cancel_rx) = unbounded_channel(); let hm = ActivityHeartbeatManager::new(Arc::new(mock_client), cancel_tx); @@ -601,6 +610,7 @@ mod test { // Mimic the same delay we would apply in activity task manager throttle_interval, None, + String::new(), ) .expect("hearbeat recording should not fail"); } diff --git a/crates/sdk-core/src/worker/client.rs b/crates/sdk-core/src/worker/client.rs index 206da074e..fc12d1f03 100644 --- a/crates/sdk-core/src/worker/client.rs +++ b/crates/sdk-core/src/worker/client.rs @@ -166,6 +166,7 @@ pub trait WorkerClient: Sync + Send { &self, task_token: TaskToken, result: Option, + resource_id: String, ) -> Result; /// Complete a Nexus task async fn complete_nexus_task( @@ -178,18 +179,21 @@ pub trait WorkerClient: Sync + Send { &self, task_token: TaskToken, details: Option, + resource_id: String, ) -> Result; /// Cancel an activity task async fn cancel_activity_task( &self, task_token: TaskToken, details: Option, + resource_id: String, ) -> Result; /// Fail an activity task async fn fail_activity_task( &self, task_token: TaskToken, failure: Option, + resource_id: String, ) -> Result; /// Fail a workflow task async fn fail_workflow_task( @@ -197,6 +201,7 @@ pub trait WorkerClient: Sync + Send { task_token: TaskToken, cause: WorkflowTaskFailedCause, failure: Option, + workflow_id: String, ) -> Result; /// Fail a Nexus task async fn fail_nexus_task( @@ -414,6 +419,7 @@ impl WorkerClient for WorkerClientBag { force_create_new_workflow_task: request.force_create_new_workflow_task, worker_version_stamp: self.worker_version_stamp(), binary_checksum: self.binary_checksum(), + resource_id: format!("workflow:{}", request.workflow_id), query_results: request .query_responses .into_iter() @@ -441,7 +447,6 @@ impl WorkerClient for WorkerClientBag { deployment: None, versioning_behavior: request.versioning_behavior.into(), deployment_options: self.deployment_options(), - resource_id: Default::default(), }; Ok(self .connection @@ -455,6 +460,7 @@ impl WorkerClient for WorkerClientBag { &self, task_token: TaskToken, result: Option, + resource_id: String, ) -> Result { Ok(self .connection @@ -470,7 +476,7 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - resource_id: Default::default(), + resource_id, } .into_request(), ) @@ -503,6 +509,7 @@ impl WorkerClient for WorkerClientBag { &self, task_token: TaskToken, details: Option, + resource_id: String, ) -> Result { Ok(self .connection @@ -513,7 +520,7 @@ impl WorkerClient for WorkerClientBag { details, identity: self.identity(), namespace: self.namespace.clone(), - resource_id: Default::default(), + resource_id, } .into_request(), ) @@ -525,6 +532,7 @@ impl WorkerClient for WorkerClientBag { &self, task_token: TaskToken, details: Option, + resource_id: String, ) -> Result { Ok(self .connection @@ -540,7 +548,7 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - resource_id: Default::default(), + resource_id, } .into_request(), ) @@ -552,6 +560,7 @@ impl WorkerClient for WorkerClientBag { &self, task_token: TaskToken, failure: Option, + resource_id: String, ) -> Result { Ok(self .connection @@ -569,7 +578,7 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - resource_id: Default::default(), + resource_id, } .into_request(), ) @@ -582,6 +591,7 @@ impl WorkerClient for WorkerClientBag { task_token: TaskToken, cause: WorkflowTaskFailedCause, failure: Option, + workflow_id: String, ) -> Result { #[allow(deprecated)] // want to list all fields explicitly let request = RespondWorkflowTaskFailedRequest { @@ -596,7 +606,7 @@ impl WorkerClient for WorkerClientBag { // Will never be set, deprecated. deployment: None, deployment_options: self.deployment_options(), - resource_id: Default::default(), + resource_id: format!("workflow:{}", workflow_id), }; Ok(self .connection @@ -746,11 +756,12 @@ impl WorkerClient for WorkerClientBag { namespace: String, worker_heartbeat: Vec, ) -> Result { + let connection = self.connection.inner_cow(); let request = RecordWorkerHeartbeatRequest { namespace, identity: self.identity(), worker_heartbeat, - resource_id: Default::default(), + resource_id: format!("worker:{}", connection.worker_grouping_key()), }; Ok(self .connection @@ -875,6 +886,8 @@ pub struct WorkflowTaskCompletion { pub metering_metadata: MeteringMetadata, /// Versioning behavior of the workflow, if any. pub versioning_behavior: VersioningBehavior, + /// Workflow ID + pub workflow_id: String, } #[derive(Clone, Default)] diff --git a/crates/sdk-core/src/worker/client/mocks.rs b/crates/sdk-core/src/worker/client/mocks.rs index 983ef139e..d29394178 100644 --- a/crates/sdk-core/src/worker/client/mocks.rs +++ b/crates/sdk-core/src/worker/client/mocks.rs @@ -94,6 +94,7 @@ mockall::mock! { &self, task_token: TaskToken, result: Option, + resource_id: String, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -108,6 +109,7 @@ mockall::mock! { &self, task_token: TaskToken, details: Option, + resource_id: String, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -115,6 +117,7 @@ mockall::mock! { &self, task_token: TaskToken, failure: Option, + resource_id: String, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -123,6 +126,7 @@ mockall::mock! { task_token: TaskToken, cause: WorkflowTaskFailedCause, failure: Option, + workflow_id: String, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; @@ -137,6 +141,7 @@ mockall::mock! { &self, task_token: TaskToken, details: Option, + resource_id: String, ) -> impl Future> + Send + 'b where 'a: 'b, Self: 'b; diff --git a/crates/sdk-core/src/worker/workflow/managed_run.rs b/crates/sdk-core/src/worker/workflow/managed_run.rs index d42dcd693..66ecd7556 100644 --- a/crates/sdk-core/src/worker/workflow/managed_run.rs +++ b/crates/sdk-core/src/worker/workflow/managed_run.rs @@ -411,6 +411,7 @@ impl ManagedRun { tt, WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure, Failure::application_failure(reason, true).into(), + self.workflow_id().to_string(), )) } else { ActivationCompleteOutcome::DoNothing @@ -439,6 +440,7 @@ impl ManagedRun { result: Box::new(qr), }, metrics: self.metrics.clone(), + workflow_id: self.workflow_id().to_owned(), }), resp_chan, ); @@ -641,7 +643,10 @@ impl ManagedRun { }); } else { ActivationCompleteOutcome::ReportWFTFail(FailedActivationWFTReport::Report( - tt, cause, failure, + tt, + cause, + failure, + self.workflow_id().to_string(), )) } } else { @@ -1079,6 +1084,7 @@ impl ManagedRun { data: CompletionDataForWFT, due_to_heartbeat_timeout: bool, ) -> FulfillableActivationComplete { + let wf_id = self.workflow_id().to_owned(); let mut machines_wft_response = self.wfm.prepare_for_wft_response(); if data.activation_was_eviction && (machines_wft_response.commands().peek().is_some() @@ -1139,6 +1145,7 @@ impl ManagedRun { attempt, }, metrics: self.metrics.clone(), + workflow_id: wf_id, }) } else { ActivationCompleteOutcome::DoNothing @@ -1222,6 +1229,10 @@ impl ManagedRun { fn run_id(&self) -> &str { &self.wfm.machines.run_id } + + fn workflow_id(&self) -> &str { + &self.wfm.machines.workflow_id + } } // Construct a new command sequence with query responses removed, and any diff --git a/crates/sdk-core/src/worker/workflow/mod.rs b/crates/sdk-core/src/worker/workflow/mod.rs index 796a98df4..c73c442b6 100644 --- a/crates/sdk-core/src/worker/workflow/mod.rs +++ b/crates/sdk-core/src/worker/workflow/mod.rs @@ -353,6 +353,7 @@ impl Workflows { attempt, }, metrics: run_metrics, + workflow_id: wf_id, } => { let reserved_act_permits = self.reserve_activity_slots_for_outgoing_commands(commands.as_mut_slice()); @@ -384,6 +385,7 @@ impl Workflows { nonfirst_local_activity_execution_attempts, }, versioning_behavior, + workflow_id: wf_id.clone(), }; let sticky_attrs = self.sticky_attrs.clone(); // Do not return new WFT if we would not cache, because returned new WFTs are @@ -436,6 +438,7 @@ impl Workflows { task_token, WorkflowTaskFailedCause::GrpcMessageTooLarge, failure, + wf_id, ); self.handle_activation_failed(run_id, completion_time, new_outcome) .await; @@ -481,11 +484,11 @@ impl Workflows { outcome: FailedActivationWFTReport, ) -> WFTReportStatus { match outcome { - FailedActivationWFTReport::Report(tt, cause, failure) => { + FailedActivationWFTReport::Report(tt, cause, failure, workflow_id) => { warn!(run_id=%run_id, failure=?failure, "Failing workflow task"); self.handle_wft_reporting_errs(run_id, || async { self.client - .fail_workflow_task(tt, cause, failure.failure) + .fail_workflow_task(tt, cause, failure.failure, workflow_id) .await }) .await; @@ -1012,7 +1015,7 @@ struct WorkflowTaskInfo { #[derive(Debug)] enum FailedActivationWFTReport { - Report(TaskToken, WorkflowTaskFailedCause, Failure), + Report(TaskToken, WorkflowTaskFailedCause, Failure, String), ReportLegacyQueryFailure(TaskToken, Failure), } @@ -1020,6 +1023,7 @@ struct ServerCommandsWithWorkflowInfo { task_token: TaskToken, action: ActivationAction, metrics: MetricsContext, + workflow_id: String, } impl Debug for ServerCommandsWithWorkflowInfo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { diff --git a/crates/sdk-core/tests/integ_tests/worker_tests.rs b/crates/sdk-core/tests/integ_tests/worker_tests.rs index 60def06ed..b41860a05 100644 --- a/crates/sdk-core/tests/integ_tests/worker_tests.rs +++ b/crates/sdk-core/tests/integ_tests/worker_tests.rs @@ -336,7 +336,7 @@ async fn activity_tasks_from_completion_reserve_slots() { ]; mock.expect_complete_activity_task() .times(2) - .returning(|_, _| Ok(RespondActivityTaskCompletedResponse::default())); + .returning(|_, _, _| Ok(RespondActivityTaskCompletedResponse::default())); let barr: &'static Barrier = Box::leak(Box::new(Barrier::new(2))); let mut mh = MockPollCfg::from_resp_batches( wf_id, diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/determinism.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/determinism.rs index 4bb5d2e6c..0a3ec43bc 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/determinism.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/determinism.rs @@ -178,7 +178,7 @@ async fn test_panic_wf_task_rejected_properly() { // We should see one wft failure which has the default cause, since panics don't have a defined // type. mh.num_expected_fails = 1; - mh.expect_fail_wft_matcher = Box::new(|_, cause, _| { + mh.expect_fail_wft_matcher = Box::new(|_, cause, _, _| { matches!( cause, WorkflowTaskFailedCause::WorkflowWorkerUnhandledFailure @@ -238,7 +238,7 @@ async fn test_wf_task_rejected_properly_due_to_nondeterminism(#[case] use_cache: ); mh.num_expected_fails = 1; mh.expect_fail_wft_matcher = - Box::new(|_, cause, _| matches!(cause, WorkflowTaskFailedCause::NonDeterministicError)); + Box::new(|_, cause, _, _| matches!(cause, WorkflowTaskFailedCause::NonDeterministicError)); let mut worker = mock_sdk_cfg(mh, |cfg| { if use_cache { cfg.max_cached_workflows = 2; @@ -336,7 +336,7 @@ async fn activity_id_or_type_change_is_nondeterministic( mock, ); mh.num_expected_fails = 1; - mh.expect_fail_wft_matcher = Box::new(move |_, cause, f| { + mh.expect_fail_wft_matcher = Box::new(move |_, cause, f, _| { let should_contain = if id_change { "does not match activity id" } else { @@ -413,7 +413,7 @@ async fn child_wf_id_or_type_change_is_nondeterministic( mock, ); mh.num_expected_fails = 1; - mh.expect_fail_wft_matcher = Box::new(move |_, cause, f| { + mh.expect_fail_wft_matcher = Box::new(move |_, cause, f, _| { let should_contain = if id_change { "does not match child workflow id" } else { diff --git a/crates/sdk-core/tests/integ_tests/workflow_tests/timers.rs b/crates/sdk-core/tests/integ_tests/workflow_tests/timers.rs index bbad6956c..529911542 100644 --- a/crates/sdk-core/tests/integ_tests/workflow_tests/timers.rs +++ b/crates/sdk-core/tests/integ_tests/workflow_tests/timers.rs @@ -214,7 +214,7 @@ async fn mismatched_timer_ids_errors() { let t = canned_histories::single_timer("badid"); let mut mock_cfg = MockPollCfg::from_hist_builder(t); mock_cfg.num_expected_fails = 1; - mock_cfg.expect_fail_wft_matcher = Box::new(move |_, cause, f| { + mock_cfg.expect_fail_wft_matcher = Box::new(move |_, cause, f, _| { matches!(cause, WorkflowTaskFailedCause::NonDeterministicError) && matches!(f, Some(Failure {message, .. }) if message.contains("Timer fired event did not have expected timer id 1"))