diff --git a/apollo-router/src/plugins/connectors/handle_responses.rs b/apollo-router/src/plugins/connectors/handle_responses.rs index 26933107d3..7d3ccf9582 100644 --- a/apollo-router/src/plugins/connectors/handle_responses.rs +++ b/apollo-router/src/plugins/connectors/handle_responses.rs @@ -1,7 +1,10 @@ +use std::sync::Arc; + use apollo_compiler::validation::Valid; use apollo_compiler::Schema; use apollo_federation::sources::connect::ApplyTo; use apollo_federation::sources::connect::Connector; +use parking_lot::Mutex; use serde_json_bytes::ByteString; use serde_json_bytes::Value; @@ -35,7 +38,7 @@ pub(crate) enum HandleResponseError { pub(crate) async fn handle_responses( responses: Vec>, connector: &Connector, - debug: &mut Option, + debug: &Option>>, _schema: &Valid, // TODO for future apply_with_selection ) -> Result { use HandleResponseError::*; @@ -58,8 +61,8 @@ pub(crate) async fn handle_responses( if parts.status.is_success() { let Ok(json_data) = serde_json::from_slice::(body) else { - if let Some(ref mut debug) = debug { - debug.push_invalid_response(&parts, body); + if let Some(debug) = debug { + debug.lock().push_invalid_response(&parts, body); } return Err(InvalidResponseBody( "couldn't deserialize response body".into(), @@ -77,8 +80,8 @@ pub(crate) async fn handle_responses( &response_key.inputs().merge(connector.config.as_ref()), ); - if let Some(ref mut debug) = debug { - debug.push_response( + if let Some(ref debug) = debug { + debug.lock().push_response( &parts, &json_data, Some(SelectionData { @@ -169,13 +172,13 @@ pub(crate) async fn handle_responses( _ => {} }; - if let Some(ref mut debug) = debug { + if let Some(ref debug) = debug { match serde_json::from_slice(body) { Ok(json_data) => { - debug.push_response(&parts, &json_data, None); + debug.lock().push_response(&parts, &json_data, None); } Err(_) => { - debug.push_invalid_response(&parts, body); + debug.lock().push_invalid_response(&parts, body); } } } @@ -300,10 +303,9 @@ mod tests { let schema = Schema::parse_and_validate("type Query { hello: String }", "./").unwrap(); - let res = - super::handle_responses(vec![response1, response2], &connector, &mut None, &schema) - .await - .unwrap(); + let res = super::handle_responses(vec![response1, response2], &connector, &None, &schema) + .await + .unwrap(); assert_debug_snapshot!(res, @r###" Response { @@ -400,10 +402,9 @@ mod tests { ) .unwrap(); - let res = - super::handle_responses(vec![response1, response2], &connector, &mut None, &schema) - .await - .unwrap(); + let res = super::handle_responses(vec![response1, response2], &connector, &None, &schema) + .await + .unwrap(); assert_debug_snapshot!(res, @r###" Response { @@ -506,10 +507,9 @@ mod tests { ) .unwrap(); - let res = - super::handle_responses(vec![response1, response2], &connector, &mut None, &schema) - .await - .unwrap(); + let res = super::handle_responses(vec![response1, response2], &connector, &None, &schema) + .await + .unwrap(); assert_debug_snapshot!(res, @r###" Response { @@ -628,7 +628,7 @@ mod tests { let res = super::handle_responses( vec![response1, response2, response3], &connector, - &mut None, + &None, &schema, ) .await diff --git a/apollo-router/src/plugins/connectors/http_json_transport.rs b/apollo-router/src/plugins/connectors/http_json_transport.rs index fa89af160e..4501ff3c6d 100644 --- a/apollo-router/src/plugins/connectors/http_json_transport.rs +++ b/apollo-router/src/plugins/connectors/http_json_transport.rs @@ -24,6 +24,7 @@ use http::HeaderMap; use http::HeaderName; use http::HeaderValue; use lazy_static::lazy_static; +use parking_lot::Mutex; use serde_json_bytes::json; use serde_json_bytes::ByteString; use serde_json_bytes::Map; @@ -67,7 +68,7 @@ pub(crate) fn make_request( transport: &HttpJsonTransport, inputs: IndexMap, original_request: &connect::Request, - debug: &mut Option, + debug: &Option>>, ) -> Result, HttpJsonTransportError> { let uri = make_uri( transport.source_url.as_ref(), @@ -100,8 +101,8 @@ pub(crate) fn make_request( &transport.headers, ); - if let Some(ref mut debug) = debug { - debug.push_request( + if let Some(debug) = debug { + debug.lock().push_request( &request, json_body.as_ref(), transport.body.as_ref().map(|body| SelectionData { diff --git a/apollo-router/src/plugins/connectors/make_requests.rs b/apollo-router/src/plugins/connectors/make_requests.rs index c30c7688e3..607363e959 100644 --- a/apollo-router/src/plugins/connectors/make_requests.rs +++ b/apollo-router/src/plugins/connectors/make_requests.rs @@ -1,9 +1,12 @@ +use std::sync::Arc; + use apollo_compiler::collections::IndexMap; use apollo_compiler::executable::Selection; use apollo_federation::sources::connect::Connector; use apollo_federation::sources::connect::CustomConfiguration; use apollo_federation::sources::connect::EntityResolver; use itertools::Itertools; +use parking_lot::Mutex; use serde_json_bytes::json; use serde_json_bytes::ByteString; use serde_json_bytes::Map; @@ -131,7 +134,7 @@ pub(crate) enum ResponseTypeName { pub(crate) fn make_requests( request: connect::Request, connector: &Connector, - debug: &mut Option, + debug: &Option>>, ) -> Result, ResponseKey)>, MakeRequestError> { let request_params = match connector.entity_resolver { Some(EntityResolver::Explicit) => entities_from_request(&request), @@ -146,7 +149,7 @@ fn request_params_to_requests( connector: &Connector, request_params: Vec, original_request: &connect::Request, - debug: &mut Option, + debug: &Option>>, ) -> Result, ResponseKey)>, MakeRequestError> { let mut results = vec![]; @@ -1280,7 +1283,7 @@ mod tests { config: Default::default(), }; - let requests = super::make_requests(req, &connector, &mut None).unwrap(); + let requests = super::make_requests(req, &connector, &None).unwrap(); assert_debug_snapshot!(requests, @r###" [ diff --git a/apollo-router/src/plugins/connectors/plugin.rs b/apollo-router/src/plugins/connectors/plugin.rs index 3593d0f5c5..c3abfd0cb4 100644 --- a/apollo-router/src/plugins/connectors/plugin.rs +++ b/apollo-router/src/plugins/connectors/plugin.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use apollo_federation::sources::connect::ApplyToError; use bytes::Bytes; use futures::future::ready; @@ -5,6 +7,7 @@ use futures::stream::once; use futures::StreamExt; use http::HeaderValue; use itertools::Itertools; +use parking_lot::Mutex; use serde::Deserialize; use serde::Serialize; use serde_json_bytes::json; @@ -57,7 +60,9 @@ impl Plugin for Connectors { == Some(&HeaderValue::from_static("true")); if is_enabled { req.context.extensions().with_lock(|mut lock| { - lock.insert::(ConnectorContext::default()); + lock.insert::>>(Arc::new(Mutex::new( + ConnectorContext::default(), + ))); }); } @@ -69,19 +74,21 @@ impl Plugin for Connectors { res = match res { Ok(mut res) => { if is_enabled { - if let Some(debug) = res - .context - .extensions() - .with_lock(|mut lock| lock.remove::()) + if let Some(debug) = + res.context.extensions().with_lock(|mut lock| { + lock.remove::>>() + }) { let (parts, stream) = res.response.into_parts(); let (mut first, rest) = stream.into_future().await; if let Some(first) = &mut first { - first.extensions.insert( - "apolloConnectorsDebugging", - json!({"version": "1", "data": debug.serialize() }), - ); + if let Some(inner) = Arc::into_inner(debug) { + first.extensions.insert( + "apolloConnectorsDebugging", + json!({"version": "1", "data": inner.into_inner().serialize() }), + ); + } } res.response = http::Response::from_parts( parts, diff --git a/apollo-router/src/services/connector_service.rs b/apollo-router/src/services/connector_service.rs index 0eaaa0a3ed..ee88ae08e3 100644 --- a/apollo-router/src/services/connector_service.rs +++ b/apollo-router/src/services/connector_service.rs @@ -8,6 +8,7 @@ use apollo_federation::sources::connect::Connector; use futures::future::BoxFuture; use indexmap::IndexMap; use opentelemetry::Key; +use parking_lot::Mutex; use tower::BoxError; use tower::ServiceExt; use tracing::Instrument; @@ -125,21 +126,18 @@ async fn execute( schema: &Valid, ) -> Result { let context = request.context.clone(); - let context2 = context.clone(); let original_subgraph_name = connector.id.subgraph_name.to_string(); - let mut debug = context + let debug = context .extensions() - .with_lock(|mut lock| lock.remove::()); + .with_lock(|lock| lock.get::>>().cloned()); - let requests = make_requests(request, connector, &mut debug).map_err(BoxError::from)?; + let requests = make_requests(request, connector, &debug).map_err(BoxError::from)?; let tasks = requests.into_iter().map(move |(req, key)| { let context = context.clone(); let original_subgraph_name = original_subgraph_name.clone(); async move { - let context = context.clone(); - let client = http_client_factory.create(&original_subgraph_name); let req = HttpRequest { http_request: req, @@ -158,17 +156,9 @@ async fn execute( .await .map_err(BoxError::from)?; - let result = handle_responses(responses, connector, &mut debug, schema) + handle_responses(responses, connector, &debug, schema) .await - .map_err(BoxError::from); - - if let Some(debug) = debug { - context2 - .extensions() - .with_lock(|mut lock| lock.insert::(debug)); - } - - result + .map_err(BoxError::from) } #[derive(Clone)]