diff --git a/dgraph/cmd/alpha/upsert_test.go b/dgraph/cmd/alpha/upsert_test.go new file mode 100644 index 00000000000..8940390a971 --- /dev/null +++ b/dgraph/cmd/alpha/upsert_test.go @@ -0,0 +1,100 @@ +/* + * Copyright 2019 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package alpha + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +// contains checks whether given element is contained +// in any of the elements of the given list of strings. +func contains(ps []string, p string) bool { + var res bool + for _, v := range ps { + res = res || strings.Contains(v, p) + } + + return res +} + +func TestUpsertExample0(t *testing.T) { + require.NoError(t, dropAll()) + require.NoError(t, alterSchema(`email: string @index(exact) .`)) + + // Mutation with wrong name + m1 := ` +upsert { + mutation { + set { + uid(v) "Wrong" . + uid(v) "ashish@dgraph.io" . + } + } + + query { + me(func: eq(email, "ashish@dgraph.io")) { + v as uid + } + } +}` + keys, preds, _, err := mutationWithTs(m1, "application/rdf", false, true, true, 0) + require.NoError(t, err) + require.True(t, len(keys) == 0) + require.True(t, contains(preds, "email")) + require.True(t, contains(preds, "name")) + + // query should return the wrong name + q1 := ` +{ + q(func: has(email)) { + uid + name + email + } +}` + res, _, err := queryWithTs(q1, "application/graphql+-", 0) + require.NoError(t, err) + require.Contains(t, res, "Wrong") + + // mutation with correct name + m2 := ` +upsert { + mutation { + set { + uid(v) "Ashish" . + } + } + + query { + me(func: eq(email, "ashish@dgraph.io")) { + v as uid + } + } +}` + keys, preds, _, err = mutationWithTs(m2, "application/rdf", false, true, true, 0) + require.NoError(t, err) + require.True(t, len(keys) == 0) + require.True(t, contains(preds, "name")) + + // query should return correct name + res, _, err = queryWithTs(q1, "application/graphql+-", 0) + require.NoError(t, err) + require.Contains(t, res, "Ashish") +} diff --git a/edgraph/server.go b/edgraph/server.go index dfa9981f87d..de538b93a56 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -501,6 +501,36 @@ func (s *Server) doMutate(ctx context.Context, mu *api.Mutation, authorize bool) } }() + needVars := findVars(gmu) + varToUID, err := doQueryInUpsert(ctx, &l, mu.Query, needVars, mu.StartTs) + if err != nil { + return resp, err + } + + if mu.Query != "" { + // does following transformations: + // * uid(v) -> 0x123 -- If v is defined in query block + // * uid(v) -> _:uid(v) -- Otherwise + getNewVal := func(s string) string { + if strings.HasPrefix(s, "uid(") { + varName := s[4 : len(s)-1] + if uid, ok := varToUID[varName]; ok { + return uid + } + + return "_:" + s + } + + return s + } + + // update the values in mutation block from the query block. + for _, nq := range append(gmu.Set, gmu.Del...) { + nq.Subject = getNewVal(nq.Subject) + nq.ObjectId = getNewVal(nq.ObjectId) + } + } + newUids, err := query.AssignUids(ctx, gmu.Set) if err != nil { return resp, err @@ -560,6 +590,89 @@ func (s *Server) doMutate(ctx context.Context, mu *api.Mutation, authorize bool) return resp, nil } +// findVars finds all the variables used in mutation block +func findVars(gmu *gql.Mutation) []string { + vars := make(map[string]struct{}) + updateVars := func(s string) { + if strings.HasPrefix(s, "uid(") { + varName := s[4 : len(s)-1] + vars[varName] = struct{}{} + } + } + for _, nq := range gmu.Set { + updateVars(nq.Subject) + updateVars(nq.ObjectId) + } + for _, nq := range gmu.Del { + updateVars(nq.Subject) + updateVars(nq.ObjectId) + } + + varsList := make([]string, 0, len(vars)) + for v := range vars { + varsList = append(varsList, v) + } + if glog.V(3) { + glog.Infof("Variables used in mutation block: %v", varsList) + } + + return varsList +} + +// doQueryInUpsert processes a query in the upsert block. +// TODO(Aman): refactor this function along with doMutate +func doQueryInUpsert(ctx context.Context, l *query.Latency, queryText string, + needVars []string, startTs uint64) (map[string]string, error) { + + varToUID := make(map[string]string) + if queryText == "" { + return varToUID, nil + } + + if startTs == 0 { + return nil, errors.Errorf("Transaction timestamp is zero") + } + + parsedReq, err := gql.ParseWithNeedVars(gql.Request{ + Str: queryText, + Variables: make(map[string]string), + }, needVars) + if err != nil { + return nil, err + } + if err = validateQuery(parsedReq.Query); err != nil { + return nil, err + } + + qr := query.Request{ + Latency: l, + GqlQuery: &parsedReq, + ReadTs: startTs, + } + if err := qr.ProcessQuery(ctx); err != nil { + return nil, errors.Wrapf(err, "while processing query: %q", queryText) + } + + if len(qr.Vars) <= 0 { + return nil, fmt.Errorf("upsert query op has no variables") + } + + // TODO(Aman): allow multiple values for each variable. + // If a variable doesn't have any UID, we generate one ourselves later. + for name, v := range qr.Vars { + if v.Uids == nil { + continue + } + if len(v.Uids.Uids) > 1 { + return nil, fmt.Errorf("more than one values found for var (%s)", name) + } else if len(v.Uids.Uids) == 1 { + varToUID[name] = fmt.Sprintf("%d", v.Uids.Uids[0]) + } + } + + return varToUID, nil +} + // Query handles queries and returns the data. func (s *Server) Query(ctx context.Context, req *api.Request) (*api.Response, error) { if err := authorizeQuery(ctx, req); err != nil { diff --git a/gql/parser.go b/gql/parser.go index 4a7de40f181..6f09fa8224b 100644 --- a/gql/parser.go +++ b/gql/parser.go @@ -555,6 +555,12 @@ func ParseWithNeedVars(r Request, needVars []string) (res Result, rerr error) { } allVars := res.QueryVars + // Add the variables that are needed outside the query block. + // For example, mutation block in upsert block will be using + // variables from the query block that is getting parsed here. + if len(needVars) != 0 { + allVars = append(allVars, &Vars{Needs: needVars}) + } if err := checkDependency(allVars); err != nil { return res, err } diff --git a/query/query.go b/query/query.go index d4ba0fc9b91..a83e89b7661 100644 --- a/query/query.go +++ b/query/query.go @@ -1491,7 +1491,7 @@ AssignStep: // Updates the doneVars map by picking up uid/values from the current Subgraph func (sg *SubGraph) updateVars(doneVars map[string]varValue, sgPath []*SubGraph) error { - // NOTE: although we initialize doneVars (req.vars) in ProcessQuery, this nil check is for + // NOTE: although we initialize doneVars (req.Vars) in ProcessQuery, this nil check is for // non-root lookups that happen to other nodes. Don't use len(doneVars) == 0 ! if doneVars == nil || (sg.Params.Var == "" && sg.Params.FacetVar == nil) { return nil @@ -2593,7 +2593,7 @@ type Request struct { Subgraphs []*SubGraph - vars map[string]varValue + Vars map[string]varValue } // ProcessQuery processes query part of the request (without mutations). @@ -2605,7 +2605,7 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { defer stop() // doneVars stores the processed variables. - req.vars = make(map[string]varValue) + req.Vars = make(map[string]varValue) loopStart := time.Now() queries := req.GqlQuery.Query for i := 0; i < len(queries); i++ { @@ -2647,7 +2647,7 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { } // The variable should be defined in this block or should have already been // populated by some other block, otherwise we are not ready to execute yet. - _, ok := req.vars[v] + _, ok := req.Vars[v] if !ok && !selfDep { return false } @@ -2671,7 +2671,7 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { continue } - err = sg.recursiveFillVars(req.vars) + err = sg.recursiveFillVars(req.Vars) if err != nil { return err } @@ -2719,10 +2719,10 @@ func (req *Request) ProcessQuery(ctx context.Context) (err error) { sg := req.Subgraphs[idx] var sgPath []*SubGraph - if err := sg.populateVarMap(req.vars, sgPath); err != nil { + if err := sg.populateVarMap(req.Vars, sgPath); err != nil { return err } - if err := sg.populatePostAggregation(req.vars, []*SubGraph{}, nil); err != nil { + if err := sg.populatePostAggregation(req.Vars, []*SubGraph{}, nil); err != nil { return err } }