Skip to content

Commit

Permalink
Fix(GraphQL): Fix Aggregate queries on empty data (#7119)
Browse files Browse the repository at this point in the history
* Fix aggregate queries on empty data

* Empty commit to make the CLA pass
  • Loading branch information
vmrajas authored Dec 11, 2020
1 parent 61ff62f commit 8b8b0ad
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 40 deletions.
14 changes: 2 additions & 12 deletions graphql/e2e/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1620,12 +1620,7 @@ func TestChildAggregateQueryWithDeepRBAC(t *testing.T) {
[
{
"username": "user1",
"issuesAggregate":
{
"count": null,
"msgMax": null,
"msgMin": null
}
"issuesAggregate": null
}
]
}`},
Expand Down Expand Up @@ -1687,12 +1682,7 @@ func TestChildAggregateQueryWithOtherFields(t *testing.T) {
{
"username": "user1",
"issues":[],
"issuesAggregate":
{
"count": null,
"msgMin": null,
"msgMax": null
}
"issuesAggregate": null
}
]
}`},
Expand Down
2 changes: 2 additions & 0 deletions graphql/e2e/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,12 @@ func RunAll(t *testing.T) {
t.Run("persisted query", persistedQuery)
t.Run("query aggregate without filter", queryAggregateWithoutFilter)
t.Run("query aggregate with filter", queryAggregateWithFilter)
t.Run("query aggregate on empty data", queryAggregateOnEmptyData)
t.Run("query aggregate with alias", queryAggregateWithAlias)
t.Run("query aggregate with repeated fields", queryAggregateWithRepeatedFields)
t.Run("query aggregate at child level", queryAggregateAtChildLevel)
t.Run("query aggregate at child level with filter", queryAggregateAtChildLevelWithFilter)
t.Run("query aggregate at child level with empty data", queryAggregateAtChildLevelWithEmptyData)
t.Run("query aggregate at child level with multiple alias", queryAggregateAtChildLevelWithMultipleAlias)
t.Run("query aggregate at child level with repeated fields", queryAggregateAtChildLevelWithRepeatedFields)
t.Run("query aggregate and other fields at child level", queryAggregateAndOtherFieldsAtChildLevel)
Expand Down
39 changes: 31 additions & 8 deletions graphql/e2e/common/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2833,8 +2833,10 @@ func queryAggregateWithFilter(t *testing.T) {
}
}`,
string(gqlResponse.Data))
}

queryPostParams = &GraphQLParams{
func queryAggregateOnEmptyData(t *testing.T) {
queryPostParams := &GraphQLParams{
Query: `query {
aggregatePost (filter: {title : { anyofterms : "Nothing" }} ) {
count
Expand All @@ -2844,16 +2846,11 @@ func queryAggregateWithFilter(t *testing.T) {
}`,
}

gqlResponse = queryPostParams.ExecuteAsPost(t, GraphqlURL)
gqlResponse := queryPostParams.ExecuteAsPost(t, GraphqlURL)
RequireNoGQLErrors(t, gqlResponse)
testutil.CompareJSON(t,
`{
"aggregatePost":
{
"count":0,
"numLikesMax": 0,
"titleMin": "0.000000"
}
"aggregatePost": null
}`,
string(gqlResponse.Data))
}
Expand Down Expand Up @@ -3011,6 +3008,32 @@ func queryAggregateAtChildLevelWithFilter(t *testing.T) {
string(gqlResponse.Data))
}

func queryAggregateAtChildLevelWithEmptyData(t *testing.T) {
queryNumberOfIndianStates := &GraphQLParams{
Query: `query
{
queryCountry(filter: { name: { eq: "India" } }) {
name
ag : statesAggregate(filter: {xcode: {in: ["nothing"]}}) {
count
nameMin
}
}
}`,
}
gqlResponse := queryNumberOfIndianStates.ExecuteAsPost(t, GraphqlURL)
RequireNoGQLErrors(t, gqlResponse)
testutil.CompareJSON(t,
`
{
"queryCountry": [{
"name": "India",
"ag": null
}]
}`,
string(gqlResponse.Data))
}

func queryAggregateAtChildLevelWithMultipleAlias(t *testing.T) {
queryNumberOfIndianStates := &GraphQLParams{
Query: `query
Expand Down
5 changes: 3 additions & 2 deletions graphql/resolve/auth_query_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -617,13 +617,13 @@
dgquery: |-
query {
aggregateProject() {
nameMin : min(val(nameVar))
count : max(val(countVar))
nameMin : min(val(nameVar))
randomMin : min(val(randomVar))
}
var(func: uid(ProjectRoot)) {
nameVar as Project.name
countVar as count(uid)
nameVar as Project.name
randomVar as Project.random
}
ProjectRoot as var(func: uid(Project1))
Expand Down Expand Up @@ -1194,6 +1194,7 @@
ticketsAggregate_titleVar as Ticket.title
dgraph.uid : uid
}
count_ticketsAggregate : count(User.tickets) @filter(uid(TicketAggregateResult1))
titleMin_ticketsAggregate : min(val(ticketsAggregate_titleVar))
issuesAggregate : User.issues @filter(uid(IssueAggregateResult4)) {
issuesAggregate_msgVar as Issue.msg
Expand Down
48 changes: 30 additions & 18 deletions graphql/resolve/query_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,23 +191,30 @@ func aggregateQuery(query schema.Query, authRw *authRewriter) []*gql.GraphQuery
// isAggregateFunctionVisited stores if the aggregate function for a field has been added or not.
// So the map entries would contain keys as nameMin, ageMin, nameName, etc.
isAggregateFunctionVisited := make(map[string]bool)

// Add count field to aggregateQuery by default. This is done to ensure that null is
// returned in case the count of nodes is 0.
child := &gql.GraphQuery{
Var: "countVar",
Attr: "count(uid)",
}
finalQueryChild := &gql.GraphQuery{
Alias: "count",
Attr: "max(val(countVar))",
}
mainQuery.Children = append(mainQuery.Children, child)
finalMainQuery.Children = append(finalMainQuery.Children, finalQueryChild)

for _, f := range query.SelectionSet() {
// fldName stores Name of the field f.
fldName := f.Name()
if _, visited := isAggregateFunctionVisited[fldName]; visited {
continue
}
isAggregateFunctionVisited[fldName] = true
if fldName == "count" {
child := &gql.GraphQuery{
Var: "countVar",
Attr: "count(uid)",
}
finalQueryChild := &gql.GraphQuery{
Alias: fldName,
Attr: "max(val(countVar))",
}
mainQuery.Children = append(mainQuery.Children, child)
finalMainQuery.Children = append(finalMainQuery.Children, finalQueryChild)
// We continue in case of a count field in Aggregate Query as count has already
// been added by default just before the for loop.
continue
}

Expand Down Expand Up @@ -1038,6 +1045,16 @@ func buildAggregateFields(
// contain "scoreVar as Tweets.score" only once.
isAggregateFieldVisited := make(map[string]bool)

// Add the default count field. Count field is part of an AggregateField by default
// as this makes it possible to return null field in case the count of nodes is 0
aggregateChild := &gql.GraphQuery{
Alias: "count_" + fieldAlias,
Attr: "count(" + constructedForDgraphPredicate + ")",
}
// Add filter to count aggregation field.
_ = addFilter(aggregateChild, constructedForType, fieldFilter)
aggregateChildren = append(aggregateChildren, aggregateChild)

// Iterate over fields queried inside aggregate.
for _, aggregateField := range f.SelectionSet() {
// Don't add the same field twice
Expand All @@ -1046,15 +1063,10 @@ func buildAggregateFields(
}
addedAggregateField[aggregateField.DgraphAlias()] = true

// Handle count fields inside aggregate fields.
// As count fields are always part of an AggregateField by
// default (added just before this for loop). We continue
// in case of a count field.
if aggregateField.DgraphAlias() == "count" {
aggregateChild := &gql.GraphQuery{
Alias: "count_" + fieldAlias,
Attr: "count(" + constructedForDgraphPredicate + ")",
}
// Add filter to count aggregation field.
_ = addFilter(aggregateChild, constructedForType, fieldFilter)
aggregateChildren = append(aggregateChildren, aggregateChild)
continue
}
// Handle other aggregate functions than count
Expand Down
65 changes: 65 additions & 0 deletions graphql/resolve/query_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,29 @@
}
}
-
name: "Aggregate Query with no count field"
gqlquery: |
query {
aggregateCountry(filter: { name: { regexp: "/.*ust.*/" }}) {
nameMin
nm : nameMin
nameMax
}
}
dgquery: |-
query {
aggregateCountry() {
count : max(val(countVar))
nameMin : min(val(nameVar))
nameMax : max(val(nameVar))
}
var(func: type(Country)) @filter(regexp(Country.name, /.*ust.*/)) {
countVar as count(uid)
nameVar as Country.name
}
}
-
name: "Skip directive"
variables:
Expand Down Expand Up @@ -2956,6 +2979,48 @@
statesAggregate_nameVar as State.name
dgraph.uid : uid
}
count_statesAggregate : count(Country.states)
nameMin_statesAggregate : min(val(statesAggregate_nameVar))
nameMax_statesAggregate : max(val(statesAggregate_nameVar))
statesAggregate1 : Country.states @filter(eq(State.code, "state code")) {
statesAggregate1_nameVar as State.name
statesAggregate1_capitalVar as State.capital
dgraph.uid : uid
}
count_statesAggregate1 : count(Country.states) @filter(eq(State.code, "state code"))
nameMin_statesAggregate1 : min(val(statesAggregate1_nameVar))
nameMax_statesAggregate1 : max(val(statesAggregate1_nameVar))
capitalMin_statesAggregate1 : min(val(statesAggregate1_capitalVar))
dgraph.uid : uid
}
}
-
name: "Aggregate query at child level with no count field"
gqlquery: |
query {
queryCountry {
nm : name
ag : statesAggregate {
nMin : nameMin
nMax : nameMax
}
statesAggregate(filter: { code: { eq: "state code" } }) {
nMin : nameMin
nMax : nameMax
cMin : capitalMin
}
}
}
dgquery: |-
query {
queryCountry(func: type(Country)) {
name : Country.name
statesAggregate : Country.states {
statesAggregate_nameVar as State.name
dgraph.uid : uid
}
count_statesAggregate : count(Country.states)
nameMin_statesAggregate : min(val(statesAggregate_nameVar))
nameMax_statesAggregate : max(val(statesAggregate_nameVar))
statesAggregate1 : Country.states @filter(eq(State.code, "state code")) {
Expand Down
29 changes: 29 additions & 0 deletions graphql/resolve/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,35 @@ func completeObject(
}}
}
}

// Handle the case of empty data in Aggregate Queries. If count of data is equal
// to 0, set the val map to nil. This makes the aggregateField return null instead
// of returning "0.0000" for Min, Max function on strings and 0 for Min, Max functions
// on integers/float.
if strings.HasSuffix(f.Type().Name(), "AggregateResult") && val != nil {
var count json.Number
countVal := val.(map[string]interface{})["count"]
if countVal == nil {
// This case may happen in case of auth queries when the user does not have
// sufficient permission to query aggregate fields. We set val to nil in this
// case
val = nil
} else {
if count, ok = countVal.(json.Number); !ok {
// This is to handle case in which countVal is of any other type than
// json.Number. This should never happen. We return an error.
return nil, x.GqlErrorList{&x.GqlError{
Message: "Expected count field of type json.Number inside Aggregate Field",
Locations: []x.Location{f.Location()},
Path: copyPath(path),
}}
}
if count == "0" {
val = nil
}
}
}

completed, err := completeValue(append(path, f.ResponseName()), f, val)
errs = append(errs, err...)
if completed == nil {
Expand Down

0 comments on commit 8b8b0ad

Please sign in to comment.