diff --git a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap index 8466cefdfa..41691dbb59 100644 --- a/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap +++ b/apollo-router/src/configuration/snapshots/apollo_router__configuration__tests__schema_generation.snap @@ -115,6 +115,21 @@ expression: "&schema" } ] }, + "ActualCostComputationMode": { + "oneOf": [ + { + "enum": [ + "by_subgraph" + ], + "type": "string" + }, + { + "const": "legacy", + "deprecated": true, + "type": "string" + } + ] + }, "All": { "enum": [ "all" @@ -9128,6 +9143,14 @@ expression: "&schema" "static_estimated": { "additionalProperties": false, "properties": { + "actual_cost_computation_mode": { + "allOf": [ + { + "$ref": "#/definitions/ActualCostComputationMode" + } + ], + "description": "The strategy used to calculate the actual cost incurred by an operation.\n\n* `by_subgraph` (default) computes the cost of each subgraph response and sums them\n to get the total query cost.\n* `legacy` computes the cost based on the final structure of the composed response, not\n including any interim structures from subgraph responses that did not make it to the\n composed response." + }, "list_size": { "description": "The assumed length of lists returned by the operation.", "format": "uint32", @@ -9138,6 +9161,21 @@ expression: "&schema" "description": "The maximum cost of a query", "format": "double", "type": "number" + }, + "subgraphs": { + "allOf": [ + { + "$ref": "#/definitions/SubgraphSubgraphStrategyLimitConfiguration" + } + ], + "default": { + "all": { + "list_size": null, + "max": null + }, + "subgraphs": {} + }, + "description": "Cost control by subgraph" } }, "required": [ @@ -10409,6 +10447,28 @@ expression: "&schema" }, "type": "object" }, + "SubgraphStrategyLimit": { + "properties": { + "list_size": { + "description": "The assumed length of lists returned by the operation for this subgraph.", + "format": "uint32", + "minimum": 0, + "type": [ + "integer", + "null" + ] + }, + "max": { + "description": "The maximum query cost routed to this subgraph.", + "format": "double", + "type": [ + "number", + "null" + ] + } + }, + "type": "object" + }, "SubgraphSubgraphApqConfiguration": { "description": "Configuration options pertaining to the subgraph server component.", "properties": { @@ -10492,6 +10552,32 @@ expression: "&schema" }, "type": "object" }, + "SubgraphSubgraphStrategyLimitConfiguration": { + "description": "Configuration options pertaining to the subgraph server component.", + "properties": { + "all": { + "allOf": [ + { + "$ref": "#/definitions/SubgraphStrategyLimit" + } + ], + "default": { + "list_size": null, + "max": null + }, + "description": "options applying to all subgraphs" + }, + "subgraphs": { + "additionalProperties": { + "$ref": "#/definitions/SubgraphStrategyLimit" + }, + "default": {}, + "description": "per subgraph options", + "type": "object" + } + }, + "type": "object" + }, "SubgraphTlsClientConfiguration": { "description": "Configuration options pertaining to the subgraph server component.", "properties": { diff --git a/apollo-router/src/configuration/subgraph.rs b/apollo-router/src/configuration/subgraph.rs index a371f4ebe2..c63561eca4 100644 --- a/apollo-router/src/configuration/subgraph.rs +++ b/apollo-router/src/configuration/subgraph.rs @@ -96,6 +96,21 @@ where pub(crate) fn get(&self, subgraph_name: &str) -> &T { self.subgraphs.get(subgraph_name).unwrap_or(&self.all) } + + // Create a new `SubgraphConfiguration` by extracting a value `V` from `T` + pub(crate) fn extract( + &self, + extract_fn: fn(&T) -> V, + ) -> SubgraphConfiguration { + SubgraphConfiguration { + all: extract_fn(&self.all), + subgraphs: self + .subgraphs + .iter() + .map(|(k, v)| (k.clone(), extract_fn(v))) + .collect(), + } + } } impl Debug for SubgraphConfiguration diff --git a/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs b/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs index 70595ab172..97fdaff64a 100644 --- a/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs +++ b/apollo-router/src/plugins/demand_control/cost_calculator/static_cost.rs @@ -1,3 +1,4 @@ +use std::ops::AddAssign; use std::sync::Arc; use ahash::HashMap; @@ -12,6 +13,8 @@ use apollo_compiler::executable::Selection; use apollo_compiler::executable::SelectionSet; use apollo_compiler::schema::ExtendedType; use apollo_federation::query_plan::serializable_document::SerializableDocument; +use serde::Deserialize; +use serde::Serialize; use serde_json_bytes::Value; use super::DemandControlError; @@ -19,6 +22,7 @@ use super::directives::IncludeDirective; use super::directives::SkipDirective; use super::schema::DemandControlledSchema; use super::schema::InputDefinition; +use crate::configuration::subgraph::SubgraphConfiguration; use crate::graphql::Response; use crate::graphql::ResponseVisitor; use crate::json_ext::Object; @@ -29,8 +33,51 @@ use crate::query_planner::Primary; use crate::query_planner::QueryPlan; use crate::spec::TYPENAME; +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub(crate) struct CostBySubgraph(HashMap); +impl CostBySubgraph { + pub(crate) fn new(subgraph: String, value: f64) -> Self { + let mut cost = Self::default(); + cost.insert(subgraph, value); + cost + } + + pub(crate) fn get(&self, subgraph: &str) -> Option { + self.0.get(subgraph).copied() + } + + fn insert(&mut self, subgraph: String, value: f64) { + self.0.insert(subgraph, value); + } + + pub(crate) fn total(&self) -> f64 { + self.0.values().sum() + } + + fn max_by_subgraph(mut self, other: Self) -> Self { + for (subgraph, value) in other.0 { + if let Some(existing_value) = self.get(&subgraph) { + self.insert(subgraph, existing_value.max(value)); + } else { + self.insert(subgraph, value); + } + } + self + } +} + +impl AddAssign for CostBySubgraph { + fn add_assign(&mut self, rhs: Self) { + for (subgraph, value) in rhs.0 { + let entry = self.0.entry(subgraph).or_default(); + *entry += value; + } + } +} + pub(crate) struct StaticCostCalculator { list_size: u32, + subgraph_list_sizes: Arc>>, supergraph_schema: Arc, subgraph_schemas: Arc>, } @@ -149,14 +196,20 @@ impl StaticCostCalculator { supergraph_schema: Arc, subgraph_schemas: Arc>, list_size: u32, + subgraph_list_sizes: Arc>>, ) -> Self { Self { list_size, + subgraph_list_sizes, supergraph_schema, subgraph_schemas, } } + fn subgraph_list_size(&self, subgraph_name: &str) -> Option { + *self.subgraph_list_sizes.get(subgraph_name) + } + /// Scores a field within a GraphQL operation, handling some expected cases where /// directives change how the query is fetched. In the case of the federation /// directive `@requires`, the cost of the required selection is added to the @@ -181,6 +234,7 @@ impl StaticCostCalculator { field: &Field, parent_type: &NamedType, list_size_from_upstream: Option, + subgraph: Option<&str>, ) -> Result { // When we pre-process the schema, __typename isn't included. So, we short-circuit here to avoid failed lookups. if field.name == TYPENAME { @@ -213,6 +267,10 @@ impl StaticCostCalculator { .and_then(|dir| dir.expected_size) { expected_size + } else if let Some(subgraph) = subgraph + && let Some(subgraph_list_size) = self.subgraph_list_size(subgraph) + { + subgraph_list_size as i32 } else { self.list_size as i32 }; @@ -234,6 +292,7 @@ impl StaticCostCalculator { &field.selection_set, field.ty().inner_named_type(), list_size_directive.as_ref(), + subgraph, )?; let mut arguments_cost = 0.0; @@ -265,6 +324,7 @@ impl StaticCostCalculator { selection_set, parent_type, list_size_directive.as_ref(), + subgraph, )?; } } @@ -288,6 +348,7 @@ impl StaticCostCalculator { ctx: &ScoringContext, fragment_spread: &FragmentSpread, list_size_directive: Option<&ListSizeDirective>, + subgraph: Option<&str>, ) -> Result { let fragment = fragment_spread.fragment_def(ctx.query).ok_or_else(|| { DemandControlError::QueryParseFailure(format!( @@ -300,6 +361,7 @@ impl StaticCostCalculator { &fragment.selection_set, fragment.type_condition(), list_size_directive, + subgraph, ) } @@ -309,6 +371,7 @@ impl StaticCostCalculator { inline_fragment: &InlineFragment, parent_type: &NamedType, list_size_directive: Option<&ListSizeDirective>, + subgraph: Option<&str>, ) -> Result { self.score_selection_set( ctx, @@ -318,6 +381,7 @@ impl StaticCostCalculator { .as_ref() .unwrap_or(parent_type), list_size_directive, + subgraph, ) } @@ -325,6 +389,7 @@ impl StaticCostCalculator { &self, operation: &Operation, ctx: &ScoringContext, + subgraph: Option<&str>, ) -> Result { let mut cost = if operation.is_mutation() { 10.0 } else { 0.0 }; @@ -335,7 +400,13 @@ impl StaticCostCalculator { ))); }; - cost += self.score_selection_set(ctx, &operation.selection_set, root_type_name, None)?; + cost += self.score_selection_set( + ctx, + &operation.selection_set, + root_type_name, + None, + subgraph, + )?; Ok(cost) } @@ -346,6 +417,7 @@ impl StaticCostCalculator { selection: &Selection, parent_type: &NamedType, list_size_directive: Option<&ListSizeDirective>, + subgraph: Option<&str>, ) -> Result { match selection { Selection::Field(f) => self.score_field( @@ -353,10 +425,13 @@ impl StaticCostCalculator { f, parent_type, list_size_directive.and_then(|dir| dir.size_of(f)), + subgraph, ), - Selection::FragmentSpread(s) => self.score_fragment_spread(ctx, s, list_size_directive), + Selection::FragmentSpread(s) => { + self.score_fragment_spread(ctx, s, list_size_directive, subgraph) + } Selection::InlineFragment(i) => { - self.score_inline_fragment(ctx, i, parent_type, list_size_directive) + self.score_inline_fragment(ctx, i, parent_type, list_size_directive, subgraph) } } } @@ -367,10 +442,17 @@ impl StaticCostCalculator { selection_set: &SelectionSet, parent_type_name: &NamedType, list_size_directive: Option<&ListSizeDirective>, + subgraph: Option<&str>, ) -> Result { let mut cost = 0.0; for selection in selection_set.selections.iter() { - cost += self.score_selection(ctx, selection, parent_type_name, list_size_directive)?; + cost += self.score_selection( + ctx, + selection, + parent_type_name, + list_size_directive, + subgraph, + )?; } Ok(cost) } @@ -393,7 +475,7 @@ impl StaticCostCalculator { &self, plan_node: &PlanNode, variables: &Object, - ) -> Result { + ) -> Result { match plan_node { PlanNode::Sequence { nodes } => self.summed_score_of_nodes(nodes, variables), PlanNode::Parallel { nodes } => self.summed_score_of_nodes(nodes, variables), @@ -424,7 +506,7 @@ impl StaticCostCalculator { subgraph: &str, operation: &SerializableDocument, variables: &Object, - ) -> Result { + ) -> Result { tracing::debug!("On subgraph {}, scoring operation: {}", subgraph, operation); let schema = self.subgraph_schemas.get(subgraph).ok_or_else(|| { @@ -436,7 +518,8 @@ impl StaticCostCalculator { let operation = operation .as_parsed() .map_err(DemandControlError::SubgraphOperationNotInitialized)?; - self.estimated(operation, schema, variables, false) + let estimated_cost = self.estimated(operation, schema, variables, false, Some(subgraph))?; + Ok(CostBySubgraph::new(subgraph.to_string(), estimated_cost)) } fn max_score_of_nodes( @@ -444,15 +527,15 @@ impl StaticCostCalculator { left: &Option>, right: &Option>, variables: &Object, - ) -> Result { + ) -> Result { match (left, right) { - (None, None) => Ok(0.0), + (None, None) => Ok(CostBySubgraph::default()), (None, Some(right)) => self.score_plan_node(right, variables), (Some(left), None) => self.score_plan_node(left, variables), (Some(left), Some(right)) => { let left_score = self.score_plan_node(left, variables)?; let right_score = self.score_plan_node(right, variables)?; - Ok(left_score.max(right_score)) + Ok(left_score.max_by_subgraph(right_score)) } } } @@ -462,8 +545,8 @@ impl StaticCostCalculator { primary: &Primary, deferred: &Vec, variables: &Object, - ) -> Result { - let mut score = 0.0; + ) -> Result { + let mut score = CostBySubgraph::default(); if let Some(node) = &primary.node { score += self.score_plan_node(node, variables)?; } @@ -479,8 +562,8 @@ impl StaticCostCalculator { &self, nodes: &Vec, variables: &Object, - ) -> Result { - let mut sum = 0.0; + ) -> Result { + let mut sum = CostBySubgraph::default(); for node in nodes { sum += self.score_plan_node(node, variables)?; } @@ -493,6 +576,7 @@ impl StaticCostCalculator { schema: &DemandControlledSchema, variables: &Object, should_estimate_requires: bool, + subgraph: Option<&str>, ) -> Result { let mut cost = 0.0; let ctx = ScoringContext { @@ -502,10 +586,10 @@ impl StaticCostCalculator { should_estimate_requires, }; if let Some(op) = &query.operations.anonymous { - cost += self.score_operation(op, &ctx)?; + cost += self.score_operation(op, &ctx, subgraph)?; } for (_name, op) in query.operations.named.iter() { - cost += self.score_operation(op, &ctx)?; + cost += self.score_operation(op, &ctx, subgraph)?; } Ok(cost) } @@ -514,7 +598,7 @@ impl StaticCostCalculator { &self, query_plan: &QueryPlan, variables: &Object, - ) -> Result { + ) -> Result { self.score_plan_node(&query_plan.root, variables) } @@ -523,8 +607,9 @@ impl StaticCostCalculator { request: &ExecutableDocument, response: &Response, variables: &Object, + include_entities: bool, ) -> Result { - let mut visitor = ResponseCostCalculator::new(&self.supergraph_schema); + let mut visitor = ResponseCostCalculator::new(&self.supergraph_schema, include_entities); visitor.visit(request, response, variables); Ok(visitor.cost) } @@ -533,11 +618,16 @@ impl StaticCostCalculator { pub(crate) struct ResponseCostCalculator<'a> { pub(crate) cost: f64, schema: &'a DemandControlledSchema, + include_entities: bool, } impl<'schema> ResponseCostCalculator<'schema> { - pub(crate) fn new(schema: &'schema DemandControlledSchema) -> Self { - Self { cost: 0.0, schema } + pub(crate) fn new(schema: &'schema DemandControlledSchema, include_entities: bool) -> Self { + Self { + cost: 0.0, + schema, + include_entities, + } } fn score_response_field( @@ -553,53 +643,61 @@ impl<'schema> ResponseCostCalculator<'schema> { if field.name == TYPENAME { return; } - if let Some(definition) = self.schema.output_field_definition(parent_ty, &field.name) { - match value { - Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => { - self.cost += definition - .cost_directive() - .map_or(0.0, |cost| cost.weight()); - } - Value::Array(items) => { - for item in items { - self.visit_list_item(request, variables, parent_ty, field, item); - } - } - Value::Object(children) => { - self.cost += definition - .cost_directive() - .map_or(1.0, |cost| cost.weight()); - self.visit_selections(request, variables, &field.selection_set, children); + + let definition = self.schema.output_field_definition(parent_ty, &field.name); + + // if the definition is None, that means one of two things: + // (1) the field is missing from the schema, or + // (2) the query is an `_entities` query. + // If the field _should_ be there and isn't, or we don't want to score entities, return now. + let is_entities_query = parent_ty == "Query" && field.name == "_entities"; + if definition.is_none() && !(is_entities_query && self.include_entities) { + tracing::debug!( + "Failed to get schema definition for field {}.{}. The resulting response cost will be a partial result.", + parent_ty, + field.name, + ); + return; + } + + match value { + Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => { + self.cost += definition + .and_then(|d| d.cost_directive()) + .map_or(0.0, |cost| cost.weight()); + } + Value::Array(items) => { + for item in items { + self.visit_list_item(request, variables, parent_ty, field, item); } } + Value::Object(children) => { + self.cost += definition + .and_then(|d| d.cost_directive()) + .map_or(1.0, |cost| cost.weight()); + self.visit_selections(request, variables, &field.selection_set, children); + } + } - if include_argument_score { - for argument in &field.arguments { - if let Some(argument_definition) = definition.argument_by_name(&argument.name) { - if let Ok(score) = score_argument( - &argument.value, - argument_definition, - self.schema, - variables, - ) { - self.cost += score; - } + if include_argument_score && let Some(definition) = definition { + for argument in &field.arguments { + if let Some(argument_definition) = definition.argument_by_name(&argument.name) { + if let Ok(score) = + score_argument(&argument.value, argument_definition, self.schema, variables) + { + self.cost += score; } else { - tracing::debug!( - "Failed to get schema definition for argument {}.{}({}:). The resulting response cost will be a partial result.", - parent_ty, - field.name, - argument.name, - ) + eprintln!("argument score is none"); } + } else { + tracing::debug!( + "Failed to get schema definition for argument {}.{}({}:). The resulting response cost will be a partial result.", + parent_ty, + field.name, + argument.name, + ) } } - } else { - tracing::debug!( - "Failed to get schema definition for field {}.{}. The resulting response cost will be a partial result.", - parent_ty, - field.name, - ) } } } @@ -658,7 +756,7 @@ mod tests { &self, query_plan: &apollo_federation::query_plan::QueryPlan, variables: &Object, - ) -> Result { + ) -> Result { let js_planner_node: PlanNode = query_plan.node.as_ref().unwrap().into(); self.score_plan_node(&js_planner_node, variables) } @@ -685,7 +783,12 @@ mod tests { .unwrap_or_default(); let schema = DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(); - let calculator = StaticCostCalculator::new(Arc::new(schema), Default::default(), 100); + let calculator = StaticCostCalculator::new( + Arc::new(schema), + Default::default(), + 100, + Default::default(), + ); calculator .estimated( @@ -693,6 +796,7 @@ mod tests { &calculator.supergraph_schema, &variables, true, + None, ) .unwrap() } @@ -713,14 +817,29 @@ mod tests { .cloned() .unwrap_or_default(); let schema = DemandControlledSchema::new(Arc::new(schema)).unwrap(); - let calculator = StaticCostCalculator::new(Arc::new(schema), Default::default(), 100); + let calculator = StaticCostCalculator::new( + Arc::new(schema), + Default::default(), + 100, + Default::default(), + ); calculator - .estimated(&query, &calculator.supergraph_schema, &variables, true) + .estimated( + &query, + &calculator.supergraph_schema, + &variables, + true, + None, + ) .unwrap() } - async fn planned_cost_js(schema_str: &str, query_str: &str, variables_str: &str) -> f64 { + async fn planned_cost_js( + schema_str: &str, + query_str: &str, + variables_str: &str, + ) -> CostBySubgraph { let config: Arc = Arc::new(Default::default()); let (schema, query) = parse_schema_and_operation(schema_str, query_str, &config); let variables = serde_json::from_str::(variables_str) @@ -767,12 +886,13 @@ mod tests { Arc::new(schema), Arc::new(demand_controlled_subgraph_schemas), 100, + Default::default(), ); calculator.planned(&query_plan, &variables).unwrap() } - fn planned_cost_rust(schema_str: &str, query_str: &str, variables_str: &str) -> f64 { + fn planned_cost_rust(schema_str: &str, query_str: &str, variables_str: &str) -> CostBySubgraph { let config: Arc = Arc::new(Default::default()); let (schema, query) = parse_schema_and_operation(schema_str, query_str, &config); let variables = serde_json::from_str::(variables_str) @@ -802,6 +922,7 @@ mod tests { Arc::new(schema), Arc::new(demand_controlled_subgraph_schemas), 100, + Default::default(), ); calculator.rust_planned(&query_plan, &variables).unwrap() @@ -823,9 +944,14 @@ mod tests { let response = Response::from_bytes(Bytes::from(response_bytes)).unwrap(); let schema = DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(); - StaticCostCalculator::new(Arc::new(schema), Default::default(), 100) - .actual(&query.executable, &response, &variables) - .unwrap() + StaticCostCalculator::new( + Arc::new(schema), + Default::default(), + 100, + Default::default(), + ) + .actual(&query.executable, &response, &variables, false) + .unwrap() } /// Actual cost of an operation on a plain, non-federated schema. @@ -851,9 +977,14 @@ mod tests { let response = Response::from_bytes(Bytes::from(response_bytes)).unwrap(); let schema = DemandControlledSchema::new(Arc::new(schema)).unwrap(); - StaticCostCalculator::new(Arc::new(schema), Default::default(), 100) - .actual(&query, &response, &variables) - .unwrap() + StaticCostCalculator::new( + Arc::new(schema), + Default::default(), + 100, + Default::default(), + ) + .actual(&query, &response, &variables, false) + .unwrap() } #[test] @@ -974,8 +1105,14 @@ mod tests { let variables = "{}"; assert_eq!(basic_estimated_cost(schema, query, variables), 102.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 102.0); - assert_eq!(planned_cost_rust(schema, query, variables), 102.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 102.0); + assert_eq!(cost_js.get("products").unwrap(), 102.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 102.0); + assert_eq!(cost_rust.get("products").unwrap(), 102.0); } #[test(tokio::test)] @@ -997,8 +1134,17 @@ mod tests { let response = include_bytes!("./fixtures/federated_ships_required_response.json"); assert_eq!(estimated_cost(schema, query, variables), 10200.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 10400.0); - assert_eq!(planned_cost_rust(schema, query, variables), 10400.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 10400.0); + assert_eq!(cost_js.get("users").unwrap(), 10100.0); + assert_eq!(cost_js.get("vehicles").unwrap(), 300.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 10400.0); + assert_eq!(cost_rust.get("users").unwrap(), 10100.0); + assert_eq!(cost_rust.get("vehicles").unwrap(), 300.0); + assert_eq!(actual_cost(schema, query, variables, response), 2.0); } @@ -1010,8 +1156,17 @@ mod tests { let response = include_bytes!("./fixtures/federated_ships_fragment_response.json"); assert_eq!(estimated_cost(schema, query, variables), 300.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 400.0); - assert_eq!(planned_cost_rust(schema, query, variables), 400.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 400.0); + assert_eq!(cost_js.get("users").unwrap(), 200.0); + assert_eq!(cost_js.get("vehicles").unwrap(), 200.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 400.0); + assert_eq!(cost_rust.get("users").unwrap(), 200.0); + assert_eq!(cost_rust.get("vehicles").unwrap(), 200.0); + assert_eq!(actual_cost(schema, query, variables, response), 6.0); } @@ -1023,8 +1178,17 @@ mod tests { let response = include_bytes!("./fixtures/federated_ships_fragment_response.json"); assert_eq!(estimated_cost(schema, query, variables), 300.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 400.0); - assert_eq!(planned_cost_rust(schema, query, variables), 400.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 400.0); + assert_eq!(cost_js.get("users").unwrap(), 200.0); + assert_eq!(cost_js.get("vehicles").unwrap(), 200.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 400.0); + assert_eq!(cost_rust.get("users").unwrap(), 200.0); + assert_eq!(cost_rust.get("vehicles").unwrap(), 200.0); + assert_eq!(actual_cost(schema, query, variables, response), 6.0); } @@ -1036,8 +1200,17 @@ mod tests { let response = include_bytes!("./fixtures/federated_ships_deferred_response.json"); assert_eq!(estimated_cost(schema, query, variables), 10200.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 10400.0); - assert_eq!(planned_cost_rust(schema, query, variables), 10400.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 10400.0); + assert_eq!(cost_js.get("users").unwrap(), 10100.0); + assert_eq!(cost_js.get("vehicles").unwrap(), 300.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 10400.0); + assert_eq!(cost_rust.get("users").unwrap(), 10100.0); + assert_eq!(cost_rust.get("vehicles").unwrap(), 300.0); + assert_eq!(actual_cost(schema, query, variables, response), 2.0); } @@ -1050,23 +1223,27 @@ mod tests { DemandControlledSchema::new(Arc::new(schema.supergraph_schema().clone())).unwrap(), ); - let calculator = StaticCostCalculator::new(schema.clone(), Default::default(), 100); + let calculator = + StaticCostCalculator::new(schema.clone(), Default::default(), 100, Default::default()); let conservative_estimate = calculator .estimated( &query.executable, &calculator.supergraph_schema, &Default::default(), true, + None, ) .unwrap(); - let calculator = StaticCostCalculator::new(schema.clone(), Default::default(), 5); + let calculator = + StaticCostCalculator::new(schema.clone(), Default::default(), 5, Default::default()); let narrow_estimate = calculator .estimated( &query.executable, &calculator.supergraph_schema, &Default::default(), true, + None, ) .unwrap(); @@ -1098,8 +1275,17 @@ mod tests { let response = include_bytes!("./fixtures/custom_cost_response.json"); assert_eq!(estimated_cost(schema, query, variables), 127.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 127.0); - assert_eq!(planned_cost_rust(schema, query, variables), 127.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 127.0); + assert_eq!(cost_js.get("subgraphWithListSize").unwrap(), 6.0); + assert_eq!(cost_js.get("subgraphWithCost").unwrap(), 121.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 127.0); + assert_eq!(cost_rust.get("subgraphWithListSize").unwrap(), 6.0); + assert_eq!(cost_rust.get("subgraphWithCost").unwrap(), 121.0); + assert_eq!(actual_cost(schema, query, variables, response), 125.0); } @@ -1111,8 +1297,17 @@ mod tests { let response = include_bytes!("./fixtures/custom_cost_response.json"); assert_eq!(estimated_cost(schema, query, variables), 127.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 127.0); - assert_eq!(planned_cost_rust(schema, query, variables), 127.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 127.0); + assert_eq!(cost_js.get("subgraphWithListSize").unwrap(), 6.0); + assert_eq!(cost_js.get("subgraphWithCost").unwrap(), 121.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 127.0); + assert_eq!(cost_rust.get("subgraphWithListSize").unwrap(), 6.0); + assert_eq!(cost_rust.get("subgraphWithCost").unwrap(), 121.0); + assert_eq!(actual_cost(schema, query, variables, response), 125.0); } @@ -1125,8 +1320,17 @@ mod tests { let response = include_bytes!("./fixtures/custom_cost_response.json"); assert_eq!(estimated_cost(schema, query, variables), 132.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 132.0); - assert_eq!(planned_cost_rust(schema, query, variables), 132.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 132.0); + assert_eq!(cost_js.get("subgraphWithListSize").unwrap(), 11.0); + assert_eq!(cost_js.get("subgraphWithCost").unwrap(), 121.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 132.0); + assert_eq!(cost_rust.get("subgraphWithListSize").unwrap(), 11.0); + assert_eq!(cost_rust.get("subgraphWithCost").unwrap(), 121.0); + assert_eq!(actual_cost(schema, query, variables, response), 125.0); } @@ -1139,8 +1343,17 @@ mod tests { let response = include_bytes!("./fixtures/custom_cost_response.json"); assert_eq!(estimated_cost(schema, query, variables), 127.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 127.0); - assert_eq!(planned_cost_rust(schema, query, variables), 127.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 127.0); + assert_eq!(cost_js.get("subgraphWithListSize").unwrap(), 6.0); + assert_eq!(cost_js.get("subgraphWithCost").unwrap(), 121.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 127.0); + assert_eq!(cost_rust.get("subgraphWithListSize").unwrap(), 6.0); + assert_eq!(cost_rust.get("subgraphWithCost").unwrap(), 121.0); + assert_eq!(actual_cost(schema, query, variables, response), 125.0); } @@ -1172,7 +1385,13 @@ mod tests { let variables = "{}"; assert_eq!(estimated_cost(schema, query, variables), 1.0); - assert_eq!(planned_cost_js(schema, query, variables).await, 1.0); - assert_eq!(planned_cost_rust(schema, query, variables), 1.0); + + let cost_js = planned_cost_js(schema, query, variables).await; + assert_eq!(cost_js.total(), 1.0); + assert_eq!(cost_js.get("subgraph").unwrap(), 1.0); + + let cost_rust = planned_cost_rust(schema, query, variables); + assert_eq!(cost_rust.total(), 1.0); + assert_eq!(cost_rust.get("subgraph").unwrap(), 1.0); } } diff --git a/apollo-router/src/plugins/demand_control/fixtures/invalid_per_subgraph.yaml b/apollo-router/src/plugins/demand_control/fixtures/invalid_per_subgraph.yaml new file mode 100644 index 0000000000..8ddabcb984 --- /dev/null +++ b/apollo-router/src/plugins/demand_control/fixtures/invalid_per_subgraph.yaml @@ -0,0 +1,13 @@ +demand_control: + enabled: true + mode: enforce + strategy: + static_estimated: + list_size: 1 + max: 10 + subgraphs: + all: + list_size: 3 + subgraphs: + products: + max: -1 diff --git a/apollo-router/src/plugins/demand_control/fixtures/per_subgraph_inheritance.yaml b/apollo-router/src/plugins/demand_control/fixtures/per_subgraph_inheritance.yaml new file mode 100644 index 0000000000..57d4acc553 --- /dev/null +++ b/apollo-router/src/plugins/demand_control/fixtures/per_subgraph_inheritance.yaml @@ -0,0 +1,16 @@ +demand_control: + enabled: true + mode: enforce + strategy: + static_estimated: + list_size: 1 + max: 10 + subgraphs: + all: + list_size: 3 + max: 5 + subgraphs: + products: + list_size: 5 + reviews: + max: 2 diff --git a/apollo-router/src/plugins/demand_control/fixtures/per_subgraph_no_inheritance.yaml b/apollo-router/src/plugins/demand_control/fixtures/per_subgraph_no_inheritance.yaml new file mode 100644 index 0000000000..bebb4fbf7c --- /dev/null +++ b/apollo-router/src/plugins/demand_control/fixtures/per_subgraph_no_inheritance.yaml @@ -0,0 +1,15 @@ +demand_control: + enabled: true + mode: enforce + strategy: + static_estimated: + list_size: 1 + max: 10 + subgraphs: + all: + list_size: 3 + subgraphs: + products: + list_size: 5 + reviews: + max: 2 diff --git a/apollo-router/src/plugins/demand_control/mod.rs b/apollo-router/src/plugins/demand_control/mod.rs index 4314e8cd1f..2a2899b1a6 100644 --- a/apollo-router/src/plugins/demand_control/mod.rs +++ b/apollo-router/src/plugins/demand_control/mod.rs @@ -1,6 +1,8 @@ //! Demand control plugin. //! This plugin will use the cost calculation algorithm to determine if a query should be allowed to execute. //! On the request path it will use estimated + +use std::collections::HashSet; use std::future; use std::ops::ControlFlow; use std::sync::Arc; @@ -27,6 +29,7 @@ use tower::ServiceBuilder; use tower::ServiceExt; use crate::Context; +use crate::configuration::subgraph::SubgraphConfiguration; use crate::error::Error; use crate::graphql; use crate::graphql::IntoGraphQLErrors; @@ -35,6 +38,7 @@ use crate::layers::ServiceBuilderExt; use crate::plugin::Plugin; use crate::plugin::PluginInit; use crate::plugins::demand_control::cost_calculator::schema::DemandControlledSchema; +use crate::plugins::demand_control::cost_calculator::static_cost::CostBySubgraph; use crate::plugins::demand_control::strategy::Strategy; use crate::plugins::demand_control::strategy::StrategyFactory; use crate::plugins::telemetry::tracing::apollo_telemetry::emit_error_event; @@ -51,6 +55,12 @@ pub(crate) const COST_ACTUAL_KEY: &str = "apollo::demand_control::actual_cost"; pub(crate) const COST_RESULT_KEY: &str = "apollo::demand_control::result"; pub(crate) const COST_STRATEGY_KEY: &str = "apollo::demand_control::strategy"; +pub(crate) const COST_BY_SUBGRAPH_ESTIMATED_KEY: &str = + "apollo::demand_control::estimated_cost_by_subgraph"; +pub(crate) const COST_BY_SUBGRAPH_ACTUAL_KEY: &str = + "apollo::demand_control::actual_cost_by_subgraph"; +pub(crate) const COST_BY_SUBGRAPH_RESULT_KEY: &str = "apollo::demand_control::result_by_subgraph"; + /// Algorithm for calculating the cost of an incoming query. #[derive(Clone, Debug, Deserialize, JsonSchema)] #[serde(deny_unknown_fields, rename_all = "snake_case")] @@ -73,6 +83,20 @@ pub(crate) enum StrategyConfig { list_size: u32, /// The maximum cost of a query max: f64, + + /// The strategy used to calculate the actual cost incurred by an operation. + /// + /// * `by_subgraph` (default) computes the cost of each subgraph response and sums them + /// to get the total query cost. + /// * `legacy` computes the cost based on the final structure of the composed response, not + /// including any interim structures from subgraph responses that did not make it to the + /// composed response. + #[serde(default)] + actual_cost_computation_mode: ActualCostComputationMode, + + /// Cost control by subgraph + #[serde(default)] + subgraphs: SubgraphConfiguration, }, #[cfg(test)] @@ -82,6 +106,76 @@ pub(crate) enum StrategyConfig { }, } +#[derive(Copy, Clone, Debug, Default, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub(crate) enum ActualCostComputationMode { + #[default] + BySubgraph, + + #[deprecated(since = "TBD", note = "use `BySubgraph` instead")] + #[warn(deprecated_in_future)] + Legacy, +} + +#[derive(Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] +pub(crate) struct SubgraphStrategyLimit { + /// The assumed length of lists returned by the operation for this subgraph. + list_size: Option, + + /// The maximum query cost routed to this subgraph. + max: Option, +} + +impl StrategyConfig { + fn validate(&self, subgraph_names: HashSet<&String>) -> Result<(), BoxError> { + #[derive(thiserror::Error, Debug)] + enum Error { + #[error("Maximum per-subgraph query cost for `{0}` is negative")] + NegativeQueryCost(String), + } + + #[allow(irrefutable_let_patterns)] + // need to destructure StrategyConfig::StaticEstimated and ignore StrategyConfig::Test + let StrategyConfig::StaticEstimated { + subgraphs, + actual_cost_computation_mode, + .. + } = self + else { + return Ok(()); + }; + + #[allow(deprecated_in_future)] + if matches!( + actual_cost_computation_mode, + ActualCostComputationMode::Legacy + ) { + tracing::warn!( + "Actual cost computation mode `legacy` will be deprecated in the future; migrate to `by_subgraph` when possible", + ); + } + + if subgraphs.all.max.is_some_and(|s| s < 0.0) { + return Err(Error::NegativeQueryCost("all".to_string()).into()); + } + + for (subgraph_name, subgraph_config) in subgraphs.subgraphs.iter() { + if !subgraph_names.contains(subgraph_name) { + tracing::warn!( + "Subgraph `{subgraph_name}` missing from schema but was specified in per-subgraph demand cost; it will be ignored" + ); + continue; + } + + if subgraph_config.max.is_some_and(|s| s < 0.0) { + return Err(Error::NegativeQueryCost(subgraph_name.to_string()).into()); + } + } + + Ok(()) + } +} + #[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, Eq, PartialEq)] #[serde(deny_unknown_fields, rename_all = "snake_case")] pub(crate) enum Mode { @@ -112,6 +206,15 @@ pub(crate) enum DemandControlError { /// The maximum cost of the query max_cost: f64, }, + /// subgraph {subgraph} query estimated cost {estimated_cost} exceeded configured maximum {max_cost} + EstimatedSubgraphCostTooExpensive { + /// The subgraph in question + subgraph: String, + /// The estimated cost of the query + estimated_cost: f64, + /// The maximum cost of the query + max_cost: f64, + }, /// auery actual cost {actual_cost} exceeded configured maximum {max_cost} #[allow(dead_code)] ActualCostTooExpensive { @@ -148,6 +251,24 @@ impl IntoGraphQLErrors for DemandControlError { .build(), ]) } + DemandControlError::EstimatedSubgraphCostTooExpensive { + ref subgraph, + estimated_cost, + max_cost, + } => { + let mut extensions = Object::new(); + // TODO: better extensions names? + extensions.insert("cost.subgraph", subgraph.clone().into()); + extensions.insert("cost.subgraph.estimated", estimated_cost.into()); + extensions.insert("cost.subgraph.max", max_cost.into()); + Ok(vec![ + graphql::Error::builder() + .extension_code(self.code()) + .extensions(extensions) + .message(self.to_string()) + .build(), + ]) + } DemandControlError::ActualCostTooExpensive { actual_cost, max_cost, @@ -195,6 +316,9 @@ impl DemandControlError { fn code(&self) -> &'static str { match self { DemandControlError::EstimatedCostTooExpensive { .. } => "COST_ESTIMATED_TOO_EXPENSIVE", + DemandControlError::EstimatedSubgraphCostTooExpensive { .. } => { + "SUBGRAPH_COST_ESTIMATED_TOO_EXPENSIVE" + } DemandControlError::ActualCostTooExpensive { .. } => "COST_ACTUAL_TOO_EXPENSIVE", DemandControlError::QueryParseFailure(_) => "COST_QUERY_PARSE_FAILURE", DemandControlError::SubgraphOperationNotInitialized(_) => { @@ -251,6 +375,32 @@ impl Context { .map_err(|e| DemandControlError::ContextSerializationError(e.to_string())) } + pub(crate) fn insert_estimated_cost_by_subgraph( + &self, + cost: CostBySubgraph, + ) -> Result<(), DemandControlError> { + self.insert(COST_BY_SUBGRAPH_ESTIMATED_KEY, cost) + .map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?; + Ok(()) + } + + pub(crate) fn get_estimated_cost_by_subgraph( + &self, + ) -> Result, DemandControlError> { + self.get::<&str, CostBySubgraph>(COST_BY_SUBGRAPH_ESTIMATED_KEY) + .map_err(|e| DemandControlError::ContextSerializationError(e.to_string())) + } + + pub(crate) fn get_estimated_cost_for_subgraph( + &self, + subgraph: &str, + ) -> Result, DemandControlError> { + let cost_by_subgraph_opt = self + .get::<&str, CostBySubgraph>(COST_BY_SUBGRAPH_ESTIMATED_KEY) + .map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?; + Ok(cost_by_subgraph_opt.and_then(|cost_by_subgraph| cost_by_subgraph.get(subgraph))) + } + pub(crate) fn insert_actual_cost(&self, cost: f64) -> Result<(), DemandControlError> { self.insert(COST_ACTUAL_KEY, cost) .map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?; @@ -262,6 +412,29 @@ impl Context { .map_err(|e| DemandControlError::ContextSerializationError(e.to_string())) } + pub(crate) fn update_actual_cost_by_subgraph( + &self, + cost: CostBySubgraph, + ) -> Result<(), DemandControlError> { + // combine this cost with the cost that already exists in the context + self.upsert( + COST_BY_SUBGRAPH_ACTUAL_KEY, + |mut existing_cost: CostBySubgraph| { + existing_cost += cost; + existing_cost + }, + ) + .map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?; + Ok(()) + } + + pub(crate) fn get_actual_cost_by_subgraph( + &self, + ) -> Result, DemandControlError> { + self.get::<&str, CostBySubgraph>(COST_BY_SUBGRAPH_ACTUAL_KEY) + .map_err(|e| DemandControlError::ContextSerializationError(e.to_string())) + } + pub(crate) fn get_cost_delta(&self) -> Result, DemandControlError> { let estimated = self.get_estimated_cost()?; let actual = self.get_actual_cost()?; @@ -279,6 +452,21 @@ impl Context { .map_err(|e| DemandControlError::ContextSerializationError(e.to_string())) } + pub(crate) fn insert_cost_by_subgraph_result( + &self, + subgraph: String, + result: String, + ) -> Result<(), DemandControlError> { + self.upsert( + COST_BY_SUBGRAPH_RESULT_KEY, + |mut map: HashMap| { + map.insert(subgraph, result); + map + }, + ) + .map_err(|e| DemandControlError::ContextSerializationError(e.to_string())) + } + pub(crate) fn insert_cost_strategy(&self, strategy: String) -> Result<(), DemandControlError> { self.insert(COST_STRATEGY_KEY, strategy) .map_err(|e| DemandControlError::ContextSerializationError(e.to_string()))?; @@ -348,6 +536,10 @@ impl Plugin for DemandControl { .insert(subgraph_name.clone(), demand_controlled_subgraph_schema); } + // validate that per-subgraph maxes are all non-negative + let subgraph_names = init.subgraph_schemas.keys().collect(); + init.config.strategy.validate(subgraph_names)?; + Ok(DemandControl { strategy_factory: StrategyFactory::new( init.config.clone(), @@ -501,7 +693,7 @@ impl Plugin for DemandControl { |(subgraph_name, req): (String, Arc>), fut| async move { let resp: subgraph::Response = fut.await?; let strategy = resp.context.get_demand_control_context().map(|c| c.strategy).expect("must have strategy"); - Ok(match strategy.on_subgraph_response(req.as_ref(), &resp) { + Ok(match strategy.on_subgraph_response(subgraph_name.clone(), req.as_ref(), &resp) { Ok(_) => resp, Err(err) => subgraph::Response::builder() .errors( diff --git a/apollo-router/src/plugins/demand_control/strategy/mod.rs b/apollo-router/src/plugins/demand_control/strategy/mod.rs index 738b935b12..f9892eeaea 100644 --- a/apollo-router/src/plugins/demand_control/strategy/mod.rs +++ b/apollo-router/src/plugins/demand_control/strategy/mod.rs @@ -4,11 +4,14 @@ use ahash::HashMap; use apollo_compiler::ExecutableDocument; use crate::Context; +use crate::configuration::subgraph::SubgraphConfiguration; use crate::graphql; +use crate::plugins::demand_control::ActualCostComputationMode; use crate::plugins::demand_control::DemandControlConfig; use crate::plugins::demand_control::DemandControlError; use crate::plugins::demand_control::Mode; use crate::plugins::demand_control::StrategyConfig; +use crate::plugins::demand_control::SubgraphStrategyLimit; use crate::plugins::demand_control::cost_calculator::schema::DemandControlledSchema; use crate::plugins::demand_control::cost_calculator::static_cost::StaticCostCalculator; use crate::plugins::demand_control::strategy::static_estimated::StaticEstimated; @@ -50,10 +53,14 @@ impl Strategy { pub(crate) fn on_subgraph_response( &self, + subgraph_name: String, request: &ExecutableDocument, response: &subgraph::Response, ) -> Result<(), DemandControlError> { - match self.inner.on_subgraph_response(request, response) { + match self + .inner + .on_subgraph_response(subgraph_name, request, response) + { Err(e) if self.mode == Mode::Enforce => Err(e), _ => Ok(()), } @@ -91,16 +98,44 @@ impl StrategyFactory { } } + // Function extracted for use in tests - allows us to build a `StaticEstimated` directly rather + // than a `impl StrategyImpl` + fn create_static_estimated_strategy( + &self, + list_size: u32, + max: f64, + actual_cost_computation_mode: ActualCostComputationMode, + subgraphs: &SubgraphConfiguration, + ) -> StaticEstimated { + let subgraph_list_sizes = Arc::new(subgraphs.extract(|strategy| strategy.list_size)); + let subgraph_maxes = Arc::new(subgraphs.extract(|strategy| strategy.max)); + let cost_calculator = StaticCostCalculator::new( + self.supergraph_schema.clone(), + self.subgraph_schemas.clone(), + list_size, + subgraph_list_sizes, + ); + StaticEstimated { + max, + subgraph_maxes, + actual_cost_computation_mode, + cost_calculator, + } + } + pub(crate) fn create(&self) -> Strategy { let strategy: Arc = match &self.config.strategy { - StrategyConfig::StaticEstimated { list_size, max } => Arc::new(StaticEstimated { - max: *max, - cost_calculator: StaticCostCalculator::new( - self.supergraph_schema.clone(), - self.subgraph_schemas.clone(), - *list_size, - ), - }), + StrategyConfig::StaticEstimated { + list_size, + max, + actual_cost_computation_mode, + subgraphs, + } => Arc::new(self.create_static_estimated_strategy( + *list_size, + *max, + *actual_cost_computation_mode, + subgraphs, + )), #[cfg(test)] StrategyConfig::Test { stage, error } => Arc::new(test::Test { stage: stage.clone(), @@ -120,6 +155,7 @@ pub(crate) trait StrategyImpl: Send + Sync { fn on_subgraph_response( &self, + subgraph_name: String, request: &ExecutableDocument, response: &subgraph::Response, ) -> Result<(), DemandControlError>; diff --git a/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs b/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs index ea442b3a69..c3412cfafe 100644 --- a/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs +++ b/apollo-router/src/plugins/demand_control/strategy/static_estimated.rs @@ -1,7 +1,12 @@ +use std::sync::Arc; + use apollo_compiler::ExecutableDocument; +use crate::configuration::subgraph::SubgraphConfiguration; use crate::graphql; +use crate::plugins::demand_control::ActualCostComputationMode; use crate::plugins::demand_control::DemandControlError; +use crate::plugins::demand_control::cost_calculator::static_cost::CostBySubgraph; use crate::plugins::demand_control::cost_calculator::static_cost::StaticCostCalculator; use crate::plugins::demand_control::strategy::StrategyImpl; use crate::services::execution; @@ -11,9 +16,17 @@ use crate::services::subgraph; pub(crate) struct StaticEstimated { // The estimated value of the demand pub(crate) max: f64, + pub(crate) subgraph_maxes: Arc>>, + pub(crate) actual_cost_computation_mode: ActualCostComputationMode, pub(crate) cost_calculator: StaticCostCalculator, } +impl StaticEstimated { + fn subgraph_max(&self, subgraph_name: &str) -> Option { + *self.subgraph_maxes.get(subgraph_name) + } +} + impl StrategyImpl for StaticEstimated { fn on_execution_request(&self, request: &execution::Request) -> Result<(), DemandControlError> { self.cost_calculator @@ -21,12 +34,15 @@ impl StrategyImpl for StaticEstimated { &request.query_plan, &request.supergraph_request.body().variables, ) - .and_then(|cost| { + .and_then(|cost_by_subgraph| { + let cost = cost_by_subgraph.total(); request .context .insert_cost_strategy("static_estimated".to_string())?; - request.context.insert_cost_result("COST_OK".to_string())?; request.context.insert_estimated_cost(cost)?; + request + .context + .insert_estimated_cost_by_subgraph(cost_by_subgraph)?; if cost > self.max { let error = DemandControlError::EstimatedCostTooExpensive { @@ -38,40 +54,177 @@ impl StrategyImpl for StaticEstimated { .insert_cost_result(error.code().to_string())?; Err(error) } else { + request.context.insert_cost_result("COST_OK".to_string())?; Ok(()) } }) } - fn on_subgraph_request(&self, _request: &subgraph::Request) -> Result<(), DemandControlError> { - Ok(()) + /// Reject subgraph requests when the total subgraph cost exceeds the subgraph max. + fn on_subgraph_request(&self, request: &subgraph::Request) -> Result<(), DemandControlError> { + let subgraph_name = request.subgraph_name.clone(); + let subgraph_cost = request + .context + .get_estimated_cost_for_subgraph(&subgraph_name)?; + + if let Some(subgraph_cost) = subgraph_cost + && let Some(subgraph_max) = self.subgraph_max(&subgraph_name) + && subgraph_cost > subgraph_max + { + let error = DemandControlError::EstimatedSubgraphCostTooExpensive { + subgraph: subgraph_name.clone(), + estimated_cost: subgraph_cost, + max_cost: subgraph_max, + }; + request + .context + .insert_cost_by_subgraph_result(subgraph_name, error.code().to_string())?; + Err(error) + } else { + request + .context + .insert_cost_by_subgraph_result(subgraph_name, "COST_OK".to_string())?; + Ok(()) + } } + /// Determine actual cost for this specific subgraph request fn on_subgraph_response( &self, - _request: &ExecutableDocument, - _response: &subgraph::Response, + subgraph_name: String, + request: &ExecutableDocument, + response: &subgraph::Response, ) -> Result<(), DemandControlError> { + if !matches!( + self.actual_cost_computation_mode, + ActualCostComputationMode::BySubgraph + ) { + return Ok(()); + } + + let subgraph_response_body = response.response.body(); + let cost = self.cost_calculator.actual( + request, + subgraph_response_body, + &response + .context + .extensions() + .with_lock(|lock| lock.get().cloned()) + .unwrap_or_default(), + true, + )?; + + response + .context + .update_actual_cost_by_subgraph(CostBySubgraph::new(subgraph_name, cost))?; + Ok(()) } + /// Sum up each subgraph response cost to determine the total cost fn on_execution_response( &self, context: &crate::Context, request: &ExecutableDocument, response: &graphql::Response, ) -> Result<(), DemandControlError> { - if response.data.is_some() { - let cost = self.cost_calculator.actual( + if response.data.is_none() { + return Ok(()); + } + + let cost = match self.actual_cost_computation_mode { + ActualCostComputationMode::BySubgraph => context + .get_actual_cost_by_subgraph()? + .map_or(0.0, |cost| cost.total()), + #[allow(deprecated_in_future)] + ActualCostComputationMode::Legacy => self.cost_calculator.actual( request, response, &context .extensions() .with_lock(|lock| lock.get().cloned()) .unwrap_or_default(), - )?; - context.insert_actual_cost(cost)?; - } + false, + )?, + }; + + context.insert_actual_cost(cost)?; Ok(()) } } + +#[cfg(test)] +mod tests { + use tower::BoxError; + + use super::StaticEstimated; + use crate::plugins::demand_control::DemandControl; + use crate::plugins::demand_control::StrategyConfig; + use crate::plugins::test::PluginTestHarness; + + async fn load_config_and_extract_strategy( + config: &'static str, + ) -> Result { + let schema_str = + include_str!("../cost_calculator/fixtures/basic_supergraph_schema.graphql"); + let plugin = PluginTestHarness::::builder() + .config(config) + .schema(schema_str) + .build() + .await?; + + let StrategyConfig::StaticEstimated { + list_size, + max, + actual_cost_computation_mode, + ref subgraphs, + } = plugin.config.strategy + else { + panic!("must provide static_estimated config"); + }; + let strategy = plugin.strategy_factory.create_static_estimated_strategy( + list_size, + max, + actual_cost_computation_mode, + subgraphs, + ); + Ok(strategy) + } + + #[tokio::test] + async fn test_per_subgraph_configuration_inheritance() { + let config = include_str!("../fixtures/per_subgraph_inheritance.yaml"); + + let strategy = load_config_and_extract_strategy(config).await.unwrap(); + assert_eq!(strategy.subgraph_max("reviews").unwrap(), 2.0); + assert_eq!(strategy.subgraph_max("products").unwrap(), 5.0); + assert_eq!(strategy.subgraph_max("users").unwrap(), 5.0); + } + + #[tokio::test] + async fn test_per_subgraph_configuration_no_inheritance() { + let config = include_str!("../fixtures/per_subgraph_no_inheritance.yaml"); + + let strategy = load_config_and_extract_strategy(config).await.unwrap(); + assert_eq!(strategy.subgraph_max("reviews").unwrap(), 2.0); + assert!(strategy.subgraph_max("products").is_none()); + assert!(strategy.subgraph_max("users").is_none()); + } + + #[tokio::test] + async fn test_invalid_per_subgraph_configuration() { + let config = include_str!("../fixtures/invalid_per_subgraph.yaml"); + let strategy_result = load_config_and_extract_strategy(config).await; + + match strategy_result { + Ok(strategy) => { + eprintln!("{:?}", strategy.subgraph_maxes); + panic!("Expected error") + } + Err(err) => assert_eq!( + &err.to_string(), + "Maximum per-subgraph query cost for `products` is negative" + ), + }; + } +} diff --git a/apollo-router/src/plugins/demand_control/strategy/test.rs b/apollo-router/src/plugins/demand_control/strategy/test.rs index 3755b8a5d2..d4c9363317 100644 --- a/apollo-router/src/plugins/demand_control/strategy/test.rs +++ b/apollo-router/src/plugins/demand_control/strategy/test.rs @@ -52,6 +52,7 @@ impl StrategyImpl for Test { fn on_subgraph_response( &self, + _subgraph_name: String, _request: &ExecutableDocument, response: &Response, ) -> Result<(), DemandControlError> { diff --git a/apollo-router/src/plugins/mock_subgraphs/mod.rs b/apollo-router/src/plugins/mock_subgraphs/mod.rs index 37a37be5fa..09bb05f0e8 100644 --- a/apollo-router/src/plugins/mock_subgraphs/mod.rs +++ b/apollo-router/src/plugins/mock_subgraphs/mod.rs @@ -140,6 +140,18 @@ impl PluginPrivate for MockSubgraphsPlugin { .build() }; let response = response.body(body).unwrap(); + request + .context + .upsert( + "apollo::experimental_mock_subgraphs::subgraph_call_count", + |mut v: HashMap| { + let subgraph_value = + v.entry(request.subgraph_name.clone()).or_default(); + *subgraph_value += 1; + v + }, + ) + .unwrap(); Ok(subgraph::Response::new_from_response( response, request.context, diff --git a/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs b/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs index 1566c6b451..f8dd27f33e 100644 --- a/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs +++ b/apollo-router/src/plugins/telemetry/metrics/apollo/studio.rs @@ -11,6 +11,7 @@ use super::histogram::ListLengthHistogram; use crate::apollo_studio_interop::AggregatedExtendedReferenceStats; use crate::apollo_studio_interop::ExtendedReferenceStats; use crate::apollo_studio_interop::ReferencedEnums; +use crate::plugins::demand_control::cost_calculator::static_cost::CostBySubgraph; use crate::plugins::telemetry::apollo::LicensedOperationCountByType; use crate::plugins::telemetry::apollo_exporter::proto::reports::EnumStats; use crate::plugins::telemetry::apollo_exporter::proto::reports::InputFieldStats; @@ -374,6 +375,8 @@ pub(crate) struct SingleLimitsStats { pub(crate) strategy: Option, pub(crate) cost_estimated: Option, pub(crate) cost_actual: Option, + pub(crate) cost_by_subgraph_estimated: Option, + pub(crate) cost_by_subgraph_actual: Option, pub(crate) depth: u64, pub(crate) height: u64, pub(crate) alias_count: u64, @@ -622,6 +625,14 @@ mod test { strategy: Some("test".to_string()), cost_estimated: Some(10.0), cost_actual: Some(7.0), + cost_by_subgraph_estimated: Some(CostBySubgraph::new( + "products".to_string(), + 10.0, + )), + cost_by_subgraph_actual: Some(CostBySubgraph::new( + "products".to_string(), + 7.0, + )), depth: 2, height: 4, alias_count: 0, diff --git a/apollo-router/src/plugins/telemetry/mod.rs b/apollo-router/src/plugins/telemetry/mod.rs index e3b434ddf7..05da9767fb 100644 --- a/apollo-router/src/plugins/telemetry/mod.rs +++ b/apollo-router/src/plugins/telemetry/mod.rs @@ -1475,6 +1475,14 @@ impl Telemetry { strategy: strategy.and_then(|s| serde_json::to_string(&s.mode).ok()), cost_estimated: context.get_estimated_cost().ok().flatten(), cost_actual: context.get_actual_cost().ok().flatten(), + cost_by_subgraph_estimated: context + .get_estimated_cost_by_subgraph() + .ok() + .flatten(), + cost_by_subgraph_actual: context + .get_actual_cost_by_subgraph() + .ok() + .flatten(), // These limits are related to the Traffic Shaping feature, unrelated to the Demand Control plugin depth: query_limits.map_or(0, |ql| ql.depth as u64), diff --git a/apollo-router/tests/integration/demand_control.rs b/apollo-router/tests/integration/demand_control.rs new file mode 100644 index 0000000000..5c4b0a919b --- /dev/null +++ b/apollo-router/tests/integration/demand_control.rs @@ -0,0 +1,672 @@ +use apollo_router::Context; +use apollo_router::graphql; + +const CODE_OK: &str = "COST_OK"; +const CODE_TOO_EXPENSIVE: &str = "COST_ESTIMATED_TOO_EXPENSIVE"; +const CODE_SUBGRAPH_TOO_EXPENSIVE: &str = "SUBGRAPH_COST_ESTIMATED_TOO_EXPENSIVE"; + +fn get_strategy(context: &Context) -> String { + let field = "apollo::demand_control::strategy"; + context + .get::<_, String>(field) + .expect("can't deserialize") + .unwrap_or_else(|| panic!("context missing {field}")) +} + +fn get_result(context: &Context) -> String { + let field = "apollo::demand_control::result"; + context + .get::<_, String>(field) + .expect("can't deserialize") + .unwrap_or_else(|| panic!("context missing {field}")) +} + +fn get_result_by_subgraph(context: &Context) -> Option { + context + .get::<_, serde_json::Value>("apollo::demand_control::result_by_subgraph") + .expect("can't deserialize") +} + +fn get_actual_cost(context: &Context) -> Option { + context + .get::<_, f64>("apollo::demand_control::actual_cost") + .expect("can't deserialize") +} + +fn get_actual_cost_by_subgraph(context: &Context) -> Option { + context + .get::<_, serde_json::Value>("apollo::demand_control::actual_cost_by_subgraph") + .expect("can't deserialize") +} + +fn get_estimated_cost(context: &Context) -> Option { + context + .get::<_, f64>("apollo::demand_control::estimated_cost") + .expect("can't deserialize") +} + +fn get_estimated_cost_by_subgraph(context: &Context) -> Option { + context + .get::<_, serde_json::Value>("apollo::demand_control::estimated_cost_by_subgraph") + .expect("can't deserialize") +} + +fn get_subgraph_call_count(context: &Context) -> Option { + context + .get::<_, serde_json::Value>("apollo::experimental_mock_subgraphs::subgraph_call_count") + .expect("can't deserialize") +} + +fn estimated_too_expensive(error: &&graphql::Error) -> bool { + error + .extensions + .get("code") + .is_some_and(|code| code == CODE_TOO_EXPENSIVE) +} + +fn subgraph_estimated_too_expensive(error: &&graphql::Error) -> bool { + error + .extensions + .get("code") + .is_some_and(|code| code == CODE_SUBGRAPH_TOO_EXPENSIVE) +} + +// TODO: add tests for static_estimated as well. would be good to have tests showing the difference +// between actual costs, and asserting that estimated costs are the same + +mod basic_fragments_tests { + use apollo_router::TestHarness; + use apollo_router::services::supergraph; + use tokio_stream::StreamExt; + use tower::BoxError; + use tower::ServiceExt; + + use super::CODE_OK; + use super::CODE_SUBGRAPH_TOO_EXPENSIVE; + use super::CODE_TOO_EXPENSIVE; + use super::estimated_too_expensive; + use super::get_actual_cost; + use super::get_actual_cost_by_subgraph; + use super::get_estimated_cost; + use super::get_estimated_cost_by_subgraph; + use super::get_result; + use super::get_result_by_subgraph; + use super::get_strategy; + use super::get_subgraph_call_count; + use super::subgraph_estimated_too_expensive; + + fn schema() -> &'static str { + include_str!( + "../../src/plugins/demand_control/cost_calculator/fixtures/basic_supergraph_schema.graphql" + ) + } + + fn query() -> &'static str { + include_str!( + "../../src/plugins/demand_control/cost_calculator/fixtures/basic_fragments_query.graphql" + ) + } + + fn subgraphs() -> serde_json::Value { + serde_json::json!({ + "products": { + "query": { + "interfaceInstance1": {"__typename": "SecondObjectType", "field1": null, "field2": "hello"}, + "someUnion": {"__typename": "FirstObjectType", "innerList": []} + }, + } + }) + } + + async fn supergraph_service( + demand_control: serde_json::Value, + ) -> Result { + TestHarness::builder() + .schema(schema()) + .configuration_json(serde_json::json!({ + "include_subgraph_errors": {"all": true}, + "demand_control": demand_control, + "experimental_mock_subgraphs": subgraphs(), + }))? + .build_supergraph() + .await + } + + async fn query_supergraph_service( + demand_control: serde_json::Value, + ) -> Result { + let service = supergraph_service(demand_control).await?; + let request = supergraph::Request::fake_builder().query(query()).build()?; + service.oneshot(request).await + } + + #[tokio::test(flavor = "multi_thread")] + #[rstest::rstest] + async fn requests_within_max_are_accepted( + #[values(12.0, 15.0)] max_cost: f64, + ) -> Result<(), BoxError> { + // query total cost is 12.0; max_cost >= 12.0 should result in query being accepted + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 10, + "max": max_cost + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_OK); + assert_eq!(get_estimated_cost(&context).unwrap(), 12.0); + + assert_eq!( + get_result_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": CODE_OK }) + ); + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": 12.0 }) + ); + + // actuals + assert!(body.data.is_some()); + assert!(body.errors.is_empty()); + + let subgraph_call_count = get_subgraph_call_count(&context).unwrap(); + assert_eq!(subgraph_call_count["products"], 1); + + assert_eq!(get_actual_cost(&context).unwrap(), 2.0); + assert_eq!( + get_actual_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": 2.0 }) + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn requests_exceeding_max_are_rejected() -> Result<(), BoxError> { + let max_cost = 10.0; + + // query total cost is 12.0 for list_size = 10; since `max_cost` value is less than this, + // the response should be an error + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 10, + "max": max_cost + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_TOO_EXPENSIVE); + assert_eq!(get_estimated_cost(&context).unwrap(), 12.0); + + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": 12.0 }) + ); + + // actuals + assert!(body.data.is_none()); + + let error = body.errors.iter().find(estimated_too_expensive).unwrap(); + assert_eq!(error.extensions["cost.estimated"], 12.0); + assert_eq!(error.extensions["cost.max"], max_cost); + + assert!(get_subgraph_call_count(&context).is_none()); + + assert!(get_actual_cost(&context).is_none()); + assert!(get_actual_cost_by_subgraph(&context).is_none()); + + Ok(()) + } + + #[tokio::test] + async fn requests_which_exceed_subgraph_limit_are_partially_accepted() -> Result<(), BoxError> { + // query checks products once; query should be accepted based on max but products subgraph + // should not be called. + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 10, + "max": 15, + "subgraphs": { + "all": {}, + "subgraphs": { + "products": { + "max": 10 + } + } + } + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_OK); + assert_eq!(get_estimated_cost(&context).unwrap(), 12.0); + assert_eq!( + get_result_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": CODE_SUBGRAPH_TOO_EXPENSIVE }) + ); + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": 12.0}) + ); + + // actuals + assert!(body.data.is_some()); + assert!(!body.errors.iter().any(|e| estimated_too_expensive(&e))); + + let error = body + .errors + .iter() + .find(subgraph_estimated_too_expensive) + .unwrap(); + assert_eq!(error.extensions["cost.subgraph.estimated"], 12.0); + assert_eq!(error.extensions["cost.subgraph.max"], 10.0); + + let subgraph_call_count = get_subgraph_call_count(&context).unwrap_or_default(); + assert!(subgraph_call_count.get("products").is_none()); + + assert_eq!(get_actual_cost(&context).unwrap(), 0.0); + assert!(get_actual_cost_by_subgraph(&context).is_none()); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + #[rstest::rstest] + #[case::new_cost_mode("by_subgraph", 2.0, true)] + #[case::legacy_cost_mode("legacy", 2.0, false)] + async fn actual_cost_computation_can_vary_based_on_mode( + #[case] computation_mode: &str, + #[case] expected_actual_cost: f64, + #[case] should_have_per_subgraph_actuals: bool, + ) -> Result<(), BoxError> { + // query total cost is 12.0; max_cost >= 12.0 should result in query being accepted + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 10, + "actual_cost_computation_mode": computation_mode, + "max": 20.0 + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_OK); + assert_eq!(get_estimated_cost(&context).unwrap(), 12.0); + + assert_eq!( + get_result_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": CODE_OK }) + ); + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "products": 12.0 }) + ); + + // actuals + assert!(body.data.is_some()); + assert!(body.errors.is_empty()); + + let subgraph_call_count = get_subgraph_call_count(&context).unwrap(); + assert_eq!(subgraph_call_count["products"], 1); + + assert_eq!(get_actual_cost(&context).unwrap(), expected_actual_cost); + assert_eq!( + get_actual_cost_by_subgraph(&context).is_some(), + should_have_per_subgraph_actuals + ); + + Ok(()) + } +} + +mod federated_ships_tests { + use apollo_router::TestHarness; + use apollo_router::services::supergraph; + use tokio_stream::StreamExt; + use tower::BoxError; + use tower::ServiceExt; + + use super::CODE_OK; + use super::CODE_SUBGRAPH_TOO_EXPENSIVE; + use super::CODE_TOO_EXPENSIVE; + use super::estimated_too_expensive; + use super::get_actual_cost; + use super::get_actual_cost_by_subgraph; + use super::get_estimated_cost; + use super::get_estimated_cost_by_subgraph; + use super::get_result; + use super::get_result_by_subgraph; + use super::get_strategy; + use super::get_subgraph_call_count; + use super::subgraph_estimated_too_expensive; + + fn schema() -> &'static str { + include_str!( + "../../src/plugins/demand_control/cost_calculator/fixtures/federated_ships_schema.graphql" + ) + } + + fn query() -> &'static str { + include_str!( + "../../src/plugins/demand_control/cost_calculator/fixtures/federated_ships_required_query.graphql" + ) + } + + fn subgraphs() -> serde_json::Value { + serde_json::json!({ + "vehicles": { + "query": { + "ships": [ + {"__typename": "Ship", "id": 1, "name": "Ship1", "owner": {"__typename": "User", "licenseNumber": 10},}, + {"__typename": "Ship", "id": 2, "name": "Ship2", "owner": {"__typename": "User", "licenseNumber": 11},}, + {"__typename": "Ship", "id": 3, "name": "Ship3", "owner": {"__typename": "User", "licenseNumber": 12},}, + ], + }, + "entities": [ + {"__typename": "Ship", "id": 1, "owner": {"addresses": [{"zipCode": 18263}]}, "registrationFee": 129.2}, + {"__typename": "Ship", "id": 2, "owner": {"addresses": [{"zipCode": 61027}]}, "registrationFee": 14.0}, + {"__typename": "Ship", "id": 3, "owner": {"addresses": [{"zipCode": 86204}]}, "registrationFee": 97.15}, + {"__typename": "Ship", "id": 1, "owner": null, "registrationFee": null}, + {"__typename": "Ship", "id": 2, "owner": null, "registrationFee": null}, + {"__typename": "Ship", "id": 3, "owner": null, "registrationFee": null}, + ] + }, + "users": { + "entities": [ + {"__typename": "User", "licenseNumber": 10, "addresses": [{"zipCode": 18263}]}, + {"__typename": "User", "licenseNumber": 11, "addresses": [{"zipCode": 61027}]}, + {"__typename": "User", "licenseNumber": 12, "addresses": [{"zipCode": 86204}]}, + ], + } + }) + } + + async fn supergraph_service( + demand_control: serde_json::Value, + ) -> Result { + TestHarness::builder() + .schema(schema()) + .configuration_json(serde_json::json!({ + "include_subgraph_errors": {"all": true}, + "demand_control": demand_control, + "experimental_mock_subgraphs": subgraphs(), + }))? + .build_supergraph() + .await + } + + async fn query_supergraph_service( + demand_control: serde_json::Value, + ) -> Result { + let service = supergraph_service(demand_control).await?; + let request = supergraph::Request::fake_builder().query(query()).build()?; + service.oneshot(request).await + } + + #[tokio::test(flavor = "multi_thread")] + #[rstest::rstest] + async fn requests_within_max_are_accepted( + #[values(10400.0, 10500.0)] max_cost: f64, + ) -> Result<(), BoxError> { + // query total cost is 10400 for list_size = 100; all `max_cost` values are geq than this, + // so the response should be OK + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 100, + "max": max_cost + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_OK); + assert_eq!(get_estimated_cost(&context).unwrap(), 10400.0); + + assert_eq!( + get_result_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": CODE_OK, "vehicles": CODE_OK }) + ); + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": 10100.0, "vehicles": 300.0 }) + ); + + // actuals + assert!(body.data.is_some()); + assert!(body.errors.is_empty()); + + let subgraph_call_count = get_subgraph_call_count(&context).unwrap(); + assert_eq!(subgraph_call_count["users"], 1); + assert_eq!(subgraph_call_count["vehicles"], 2); + + assert_eq!(get_actual_cost(&context).unwrap(), 15.0); + assert_eq!( + get_actual_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": 6.0, "vehicles": 9.0 }) + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + async fn requests_exceeding_max_are_rejected() -> Result<(), BoxError> { + let max_cost = 10000.0; + + // query total cost is 10400 for list_size = 100; since `max_cost` value is less than this, + // the response should be an error + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 100, + "max": max_cost + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_TOO_EXPENSIVE); + assert_eq!(get_estimated_cost(&context).unwrap(), 10400.0); + + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": 10100.0, "vehicles": 300.0 }) + ); + + // actuals + assert!(body.data.is_none()); + + let error = body.errors.iter().find(estimated_too_expensive).unwrap(); + assert_eq!(error.extensions["cost.estimated"], 10400.0); + assert_eq!(error.extensions["cost.max"], max_cost); + + assert!(get_subgraph_call_count(&context).is_none()); + + assert!(get_actual_cost(&context).is_none()); + assert!(get_actual_cost_by_subgraph(&context).is_none()); + + Ok(()) + } + + #[tokio::test] + async fn requests_which_exceed_subgraph_limit_are_partially_accepted() -> Result<(), BoxError> { + // query checks vehicles, then users, then vehicles. + // interrupting the users check via a demand control limit should still permit both vehicles + // checks. + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 100, + "max": 15000.0, + "subgraphs": { + "all": {}, + "subgraphs": { + "users": { + "max": 0 + } + } + } + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_OK); + assert_eq!(get_estimated_cost(&context).unwrap(), 10400.0); + assert_eq!( + get_result_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": CODE_SUBGRAPH_TOO_EXPENSIVE, "vehicles": CODE_OK }) + ); + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": 10100.0, "vehicles": 300.0 }) + ); + + // actuals + assert!(body.data.is_some()); + assert!(!body.errors.iter().any(|e| estimated_too_expensive(&e))); + + let error = body + .errors + .iter() + .find(subgraph_estimated_too_expensive) + .unwrap(); + assert_eq!(error.extensions["cost.subgraph.estimated"], 10100.0); + assert_eq!(error.extensions["cost.subgraph.max"], 0.0); + + let subgraph_call_count = get_subgraph_call_count(&context).unwrap(); + assert!(subgraph_call_count.get("users").is_none()); + assert_eq!(subgraph_call_count["vehicles"], 2); + + assert_eq!(get_actual_cost(&context).unwrap(), 9.0); + assert_eq!( + get_actual_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "vehicles": 9.0 }) + ); + + Ok(()) + } + + #[tokio::test(flavor = "multi_thread")] + #[rstest::rstest] + #[case::new_cost_mode("by_subgraph", 15.0, true)] + #[case::legacy_cost_mode("legacy", 3.0, false)] + async fn actual_cost_computation_can_vary_based_on_mode( + #[case] computation_mode: &str, + #[case] expected_actual_cost: f64, + #[case] should_have_per_subgraph_actuals: bool, + ) -> Result<(), BoxError> { + // estimated cost is 140 because list_size is 10, in contrast to tests above which use + // list_size 100 and therefore have estimated cost 10400 + let demand_control = serde_json::json!({ + "enabled": true, + "mode": "enforce", + "strategy": { + "static_estimated": { + "list_size": 10, + "actual_cost_computation_mode": computation_mode, + "max": 150.0 + } + } + }); + + let response = query_supergraph_service(demand_control).await?; + + let context = response.context; + let body = response.response.into_body().next().await.unwrap(); + + // estimates + assert_eq!(&get_strategy(&context), "static_estimated"); + assert_eq!(&get_result(&context), CODE_OK); + assert_eq!(get_estimated_cost(&context).unwrap(), 140.0); + + assert_eq!( + get_result_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": CODE_OK, "vehicles": CODE_OK }) + ); + assert_eq!( + get_estimated_cost_by_subgraph(&context).unwrap(), + serde_json::json!({ "users": 110.0, "vehicles": 30.0 }) + ); + + // actuals + assert!(body.data.is_some()); + assert!(body.errors.is_empty()); + + let subgraph_call_count = get_subgraph_call_count(&context).unwrap(); + assert_eq!(subgraph_call_count["users"], 1); + assert_eq!(subgraph_call_count["vehicles"], 2); + + assert_eq!(get_actual_cost(&context).unwrap(), expected_actual_cost); + assert_eq!( + get_actual_cost_by_subgraph(&context).is_some(), + should_have_per_subgraph_actuals + ); + + Ok(()) + } +} diff --git a/apollo-router/tests/integration/mod.rs b/apollo-router/tests/integration/mod.rs index dc554f1e6f..aa81b9bdc5 100644 --- a/apollo-router/tests/integration/mod.rs +++ b/apollo-router/tests/integration/mod.rs @@ -10,6 +10,7 @@ pub(crate) mod redis_monitor; mod allowed_features; mod connectors; mod coprocessor; +mod demand_control; mod docs; // In the CI environment we only install Redis on x86_64 Linux mod directives; diff --git a/apollo-router/tests/snapshots/apollo_reports__demand_control_stats.snap b/apollo-router/tests/snapshots/apollo_reports__demand_control_stats.snap index a581f9baf7..628bf2eba5 100644 --- a/apollo-router/tests/snapshots/apollo_reports__demand_control_stats.snap +++ b/apollo-router/tests/snapshots/apollo_reports__demand_control_stats.snap @@ -1,7 +1,6 @@ --- source: apollo-router/tests/apollo_reports.rs expression: report -snapshot_kind: text --- header: graph_ref: test @@ -203,9 +202,8 @@ traces_per_query: - 0 - 0 - 0 - - 0 - 1 - max_cost_actual: 20 + max_cost_actual: 18 depth: 4 height: 7 alias_count: 0 diff --git a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace-2.snap b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace-2.snap index 2419a5d911..c83e2c5b02 100644 --- a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace-2.snap +++ b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace-2.snap @@ -1,7 +1,6 @@ --- source: apollo-router/tests/apollo_reports.rs expression: report -snapshot_kind: text --- header: graph_ref: test @@ -605,7 +604,7 @@ traces_per_query: result: COST_OK strategy: static_estimated cost_estimated: 230 - cost_actual: 20 + cost_actual: 18 depth: 4 height: 7 alias_count: 0 diff --git a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace.snap b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace.snap index 2419a5d911..c83e2c5b02 100644 --- a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace.snap +++ b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace.snap @@ -1,7 +1,6 @@ --- source: apollo-router/tests/apollo_reports.rs expression: report -snapshot_kind: text --- header: graph_ref: test @@ -605,7 +604,7 @@ traces_per_query: result: COST_OK strategy: static_estimated cost_estimated: 230 - cost_actual: 20 + cost_actual: 18 depth: 4 height: 7 alias_count: 0 diff --git a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched-2.snap b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched-2.snap index 9eaa200d72..ee75edaba1 100644 --- a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched-2.snap +++ b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched-2.snap @@ -1,7 +1,6 @@ --- source: apollo-router/tests/apollo_reports.rs expression: report -snapshot_kind: text --- header: graph_ref: test @@ -605,7 +604,7 @@ traces_per_query: result: COST_OK strategy: static_estimated cost_estimated: 230 - cost_actual: 20 + cost_actual: 18 depth: 4 height: 7 alias_count: 0 @@ -1205,7 +1204,7 @@ traces_per_query: result: COST_OK strategy: static_estimated cost_estimated: 230 - cost_actual: 20 + cost_actual: 18 depth: 4 height: 7 alias_count: 0 diff --git a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched.snap b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched.snap index 9eaa200d72..ee75edaba1 100644 --- a/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched.snap +++ b/apollo-router/tests/snapshots/apollo_reports__demand_control_trace_batched.snap @@ -1,7 +1,6 @@ --- source: apollo-router/tests/apollo_reports.rs expression: report -snapshot_kind: text --- header: graph_ref: test @@ -605,7 +604,7 @@ traces_per_query: result: COST_OK strategy: static_estimated cost_estimated: 230 - cost_actual: 20 + cost_actual: 18 depth: 4 height: 7 alias_count: 0 @@ -1205,7 +1204,7 @@ traces_per_query: result: COST_OK strategy: static_estimated cost_estimated: 230 - cost_actual: 20 + cost_actual: 18 depth: 4 height: 7 alias_count: 0