diff --git a/dgraph/cmd/alpha/http.go b/dgraph/cmd/alpha/http.go index 7725e8503e4..d3bef5bd339 100644 --- a/dgraph/cmd/alpha/http.go +++ b/dgraph/cmd/alpha/http.go @@ -21,6 +21,7 @@ import ( "compress/gzip" "context" "encoding/json" + "fmt" "io" "io/ioutil" "net/http" @@ -238,6 +239,8 @@ func queryHandler(w http.ResponseWriter, r *http.Request) { x.SetStatusWithData(w, x.ErrorInvalidRequest, err.Error()) return } + // Add cost to the header. + w.Header().Set(x.DgraphCostHeader, fmt.Sprint(resp.Metrics.NumUids["_total"])) e := query.Extensions{ Txn: resp.Txn, @@ -390,6 +393,8 @@ func mutationHandler(w http.ResponseWriter, r *http.Request) { x.SetStatusWithData(w, x.ErrorInvalidRequest, err.Error()) return } + // Add cost to the header. + w.Header().Set(x.DgraphCostHeader, fmt.Sprint(resp.Metrics.NumUids["_total"])) resp.Latency.ParsingNs = uint64(parseEnd.Sub(parseStart).Nanoseconds()) e := query.Extensions{ diff --git a/dgraph/cmd/alpha/http_test.go b/dgraph/cmd/alpha/http_test.go index b39446a05d3..8d5316fbd9e 100644 --- a/dgraph/cmd/alpha/http_test.go +++ b/dgraph/cmd/alpha/http_test.go @@ -196,6 +196,38 @@ func queryWithTs(queryText, contentType, debug string, ts uint64) (string, uint6 return string(output), startTs, err } +// queryWithTsForResp query the dgraph and returns it's http response and result. +func queryWithTsForResp(queryText, contentType, debug string, ts uint64) (string, + uint64, *http.Response, error) { + params := make([]string, 0, 2) + if debug != "" { + params = append(params, "debug="+debug) + } + if ts != 0 { + params = append(params, fmt.Sprintf("startTs=%v", strconv.FormatUint(ts, 10))) + } + url := addr + "/query?" + strings.Join(params, "&") + + _, body, resp, err := runWithRetriesForResp("POST", contentType, url, queryText) + if err != nil { + return "", 0, resp, err + } + + var r res + if err := json.Unmarshal(body, &r); err != nil { + return "", 0, resp, err + } + startTs := r.Extensions.Txn.StartTs + + // Remove the extensions. + r2 := res{ + Data: r.Data, + } + output, err := json.Marshal(r2) + + return string(output), startTs, resp, err +} + type mutationResponse struct { keys []string preds []string @@ -314,6 +346,61 @@ func runRequest(req *http.Request) (*x.QueryResWithData, []byte, error) { return qr, body, nil } +func runWithRetriesForResp(method, contentType, url string, body string) ( + *x.QueryResWithData, []byte, *http.Response, error) { + + req, err := createRequest(method, contentType, url, body) + if err != nil { + return nil, nil, nil, err + } + + qr, respBody, resp, err := runRequestForResp(req) + if err != nil && strings.Contains(err.Error(), "Token is expired") { + grootAccessJwt, grootRefreshJwt, err = testutil.HttpLogin(&testutil.LoginParams{ + Endpoint: addr + "/admin", + RefreshJwt: grootRefreshJwt, + }) + if err != nil { + return nil, nil, nil, err + } + + // create a new request since the previous request would have been closed upon the err + retryReq, err := createRequest(method, contentType, url, body) + if err != nil { + return nil, nil, resp, err + } + + return runRequestForResp(retryReq) + } + return qr, respBody, resp, err +} + +// attach the grootAccessJWT to the request and sends the http request +func runRequestForResp(req *http.Request) (*x.QueryResWithData, []byte, *http.Response, error) { + client := &http.Client{} + req.Header.Set("X-Dgraph-AccessToken", grootAccessJwt) + resp, err := client.Do(req) + if err != nil { + return nil, nil, resp, err + } + if status := resp.StatusCode; status != http.StatusOK { + return nil, nil, resp, errors.Errorf("Unexpected status code: %v", status) + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, nil, resp, errors.Errorf("unable to read from body: %v", err) + } + + qr := new(x.QueryResWithData) + json.Unmarshal(body, qr) // Don't check error. + if len(qr.Errors) > 0 { + return nil, nil, resp, errors.New(qr.Errors[0].Message) + } + return qr, body, resp, nil +} + func commitWithTs(keys, preds []string, ts uint64) error { url := addr + "/commit" if ts != 0 { @@ -452,6 +539,38 @@ func TestTransactionBasicNoPreds(t *testing.T) { require.NoError(t, err) require.Equal(t, `{"data":{"balances":[{"name":"Bob","balance":"110"}]}}`, data) } +func TestTransactionForCost(t *testing.T) { + require.NoError(t, dropAll()) + require.NoError(t, alterSchema(`name: string @index(term) .`)) + + q1 := ` + { + balances(func: anyofterms(name, "Alice Bob")) { + name + balance + } + } + ` + _, _, err := queryWithTs(q1, "application/graphql+-", "", 0) + require.NoError(t, err) + + m1 := ` + { + set { + _:alice "Bob" . + _:alice "110" . + _:bob "60" . + } + } + ` + + _, err = mutationWithTs(m1, "application/rdf", false, true, 0) + require.NoError(t, err) + + _, _, resp, err := queryWithTsForResp(q1, "application/graphql+-", "", 0) + require.NoError(t, err) + require.Equal(t, "2", resp.Header.Get(x.DgraphCostHeader)) +} func TestTransactionBasicOldCommitFormat(t *testing.T) { require.NoError(t, dropAll()) diff --git a/edgraph/server.go b/edgraph/server.go index 10cd34bffd6..db0f30a279a 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -20,6 +20,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "math" "net" "sort" @@ -36,6 +37,7 @@ import ( "go.opencensus.io/tag" "go.opencensus.io/trace" otrace "go.opencensus.io/trace" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -997,13 +999,18 @@ func (s *Server) doQuery(ctx context.Context, req *api.Request, doAuth AuthMode) EncodingNs: uint64(l.Json.Nanoseconds()), TotalNs: uint64((time.Since(l.Start)).Nanoseconds()), } - + md := metadata.Pairs(x.DgraphCostHeader, fmt.Sprint(resp.Metrics.NumUids["_total"])) + grpc.SendHeader(ctx, md) return resp, nil } func processQuery(ctx context.Context, qc *queryContext) (*api.Response, error) { resp := &api.Response{} if len(qc.req.Query) == 0 { + // No query, so make the query cost 0. + resp.Metrics = &api.Metrics{ + NumUids: map[string]uint64{"_total": 0}, + } return resp, nil } if ctx.Err() != nil { diff --git a/x/x.go b/x/x.go index 004d0f3b8f7..4bb71e247ce 100644 --- a/x/x.go +++ b/x/x.go @@ -146,6 +146,7 @@ const ( AccessControlAllowedHeaders = "X-Dgraph-AccessToken, " + "Content-Type, Content-Length, Accept-Encoding, Cache-Control, " + "X-CSRF-Token, X-Auth-Token, X-Requested-With" + DgraphCostHeader = "Dgraph-TouchedUids" // GraphqlPredicates is the json representation of the predicate reserved for graphql system. GraphqlPredicates = `