Skip to content

Commit

Permalink
feat(dgraph): enabling TLS config in http zero (#6691)
Browse files Browse the repository at this point in the history
* enabling TLS config in http zero

* making zero https configured

* changing behaviour of cmux + adding test cases

* fixing zero address in test

* fixing docker files

* adding alpha in docker compose

* fixing test generate cert pool

* renaming functions based on review

* making zero https more vigilant with more checks

* changing the enabled to disabled flag

* fixing test case

* fixing zero cmd flag desc and refactoring test cases
  • Loading branch information
aman-bansal authored Oct 22, 2020
1 parent 5a6b136 commit 5482c60
Show file tree
Hide file tree
Showing 10 changed files with 521 additions and 15 deletions.
90 changes: 78 additions & 12 deletions dgraph/cmd/zero/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
package zero

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"

"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/x"
"github.com/gogo/protobuf/jsonpb"
"github.com/golang/glog"
"github.com/soheilhy/cmux"
)

// intFromQueryParam checks for name as a query param, converts it to uint64 and returns it.
Expand Down Expand Up @@ -239,23 +244,84 @@ func (st *state) pingResponse(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
}

func (st *state) serveHTTP(l net.Listener) {
srv := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 2 * time.Minute,
func (st *state) startListenHttpAndHttps(l net.Listener) {
if Zero.Conf.GetString("tls_dir") == "" && Zero.Conf.GetString("tls_disabled_route") != "" {
glog.Fatal("--tls_disabled_route is provided as an option but tls_dir is empty. Please provide --tls_dir")
}

m := cmux.New(l)
startServers(m)

go func() {
defer st.zero.closer.Done()
err := srv.Serve(l)
glog.Errorf("Stopped taking more http(s) requests. Err: %v", err)
ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
glog.Infoln("All http(s) requests finished.")
err := m.Serve()
if err != nil {
glog.Errorf("Http(s) shutdown err: %v", err)
glog.Errorf("error from cmux serve: %v", err)
}
}()
}

func startServers(m cmux.CMux) {
httpRule := m.Match(func(r io.Reader) bool {
//no tls config is provided. http is being used.
if opts.tlsDir == "" {
return true
}
//tls config is provided but none of the routes are disabled.
if len(opts.tlsDisabledRoutes) == 0 {
return false
}
path, ok := parseRequestPath(r)
// not able to parse the request. Let it be resolved via TLS
if !ok {
return false
}
for _, r := range opts.tlsDisabledRoutes {
if strings.HasPrefix(path, r) {
return true
}
}
return false
})
go startListen(httpRule)

// if tls is enabled, make tls encryption based connections as default
if Zero.Conf.GetString("tls_dir") != "" {
tlsCfg, err := x.LoadServerTLSConfig(Zero.Conf, "node.crt", "node.key")
x.Check(err)
if tlsCfg == nil {
glog.Fatalf("tls_dir is set but tls config provided is not correct. Please define correct variable --tls_dir")
}

httpsRule := m.Match(cmux.Any())
//this is chained listener. tls listener will decrypt the message and send it in plain text to HTTP server
go startListen(tls.NewListener(httpsRule, tlsCfg))
}
}

func startListen(l net.Listener) {
srv := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 2 * time.Minute,
}

err := srv.Serve(l)
glog.Errorf("Stopped taking more http(s) requests. Err: %v", err)
ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
glog.Infoln("All http(s) requests finished.")
if err != nil {
glog.Errorf("Http(s) shutdown err: %v", err)
}
}

func parseRequestPath(r io.Reader) (path string, ok bool) {
request, err := http.ReadRequest(bufio.NewReader(r))
if err != nil {
return "", false
}

return request.URL.Path, true
}
20 changes: 17 additions & 3 deletions dgraph/cmd/zero/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -51,8 +52,9 @@ type options struct {
peer string
w string
rebalanceInterval time.Duration

totalCache int64
tlsDir string
tlsDisabledRoutes []string
totalCache int64
}

var opts options
Expand Down Expand Up @@ -88,6 +90,12 @@ instances to achieve high-availability.
flag.StringP("wal", "w", "zw", "Directory storing WAL.")
flag.Duration("rebalance_interval", 8*time.Minute, "Interval for trying a predicate move.")
flag.String("enterprise_license", "", "Path to the enterprise license file.")
// TLS configurations
flag.String("tls_dir", "", "Path to directory that has TLS certificates and keys.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_client_auth", "VERIFYIFGIVEN", "Enable TLS client authentication")
flag.String("tls_disabled_route", "", "comma separated zero endpoint which will be disabled from TLS encryption."+
"Valid values are /health,/state,/removeNode,/moveTablet,/assign,/enterpriseLicense,/debug.")
}

func setupListener(addr string, port int, kind string) (listener net.Listener, err error) {
Expand Down Expand Up @@ -160,6 +168,10 @@ func run() {
}

x.PrintVersion()
var tlsDisRoutes []string
if Zero.Conf.GetString("tls_disabled_route") != "" {
tlsDisRoutes = strings.Split(Zero.Conf.GetString("tls_disabled_route"), ",")
}

opts = options{
bindall: Zero.Conf.GetBool("bindall"),
Expand All @@ -170,6 +182,8 @@ func run() {
w: Zero.Conf.GetString("wal"),
rebalanceInterval: Zero.Conf.GetDuration("rebalance_interval"),
totalCache: int64(Zero.Conf.GetInt("cache_mb")),
tlsDir: Zero.Conf.GetString("tls_dir"),
tlsDisabledRoutes: tlsDisRoutes,
}
glog.Infof("Setting Config to: %+v", opts)

Expand Down Expand Up @@ -227,7 +241,7 @@ func run() {
// Initialize the servers.
var st state
st.serveGRPC(grpcListener, store)
st.serveHTTP(httpListener)
st.startListenHttpAndHttps(httpListener)

http.HandleFunc("/health", st.pingResponse)
http.HandleFunc("/state", st.getState)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ require (
github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829
github.com/prometheus/common v0.4.1 // indirect
github.com/prometheus/procfs v0.0.0-20190517135640-51af30a78b0e // indirect
github.com/soheilhy/cmux v0.1.4
github.com/spf13/cast v1.3.0
github.com/spf13/cobra v0.0.5
github.com/spf13/pflag v1.0.3
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1
github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
Expand Down
133 changes: 133 additions & 0 deletions tlstest/zero_https/all_routes_tls/all_routes_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package all_routes_tls

import (
"crypto/tls"
"crypto/x509"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
)

type testCase struct {
url string
statusCode int
response string
}

var testCasesHttp = []testCase{
{
url: "http://localhost:6180/health",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
{
url: "http://localhost:6180/state",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
{
url: "http://localhost:6180/removeNode?id=2&group=0",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
}

func TestZeroWithAllRoutesTLSWithHTTPClient(t *testing.T) {
client := http.Client{
Timeout: time.Second * 10,
}
defer client.CloseIdleConnections()
for _, test := range testCasesHttp {
request, err := http.NewRequest("GET", test.url, nil)
require.NoError(t, err)
do, err := client.Do(request)
require.NoError(t, err)
if do != nil && do.StatusCode != test.statusCode {
t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode)
}

body := readResponseBody(t, do)
if test.response != string(body) {
t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response)
}
}
}

var testCasesHttps = []testCase{
{
url: "https://localhost:6180/health",
response: "OK",
statusCode: 200,
},
{
url: "https://localhost:6180/state",
response: "\"id\":\"1\",\"groupId\":0,\"addr\":\"zero1:5180\",\"leader\":true,\"amDead\":false",
statusCode: 200,
},
}

func TestZeroWithAllRoutesTLSWithTLSClient(t *testing.T) {
pool, err := generateCertPool("../../tls/ca.crt", true)
require.NoError(t, err)

tlsCfg := &tls.Config{RootCAs: pool, ServerName: "localhost", InsecureSkipVerify: true}
tr := &http.Transport{
IdleConnTimeout: 30 * time.Second,
DisableCompression: true,
TLSClientConfig: tlsCfg,
}
client := http.Client{
Transport: tr,
}

defer client.CloseIdleConnections()
for _, test := range testCasesHttps {
request, err := http.NewRequest("GET", test.url, nil)
require.NoError(t, err)
do, err := client.Do(request)
require.NoError(t, err)
if do != nil && do.StatusCode != test.statusCode {
t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode)
}

body := readResponseBody(t, do)
if !strings.Contains(string(body), test.response) {
t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response)
}
}
}

func readResponseBody(t *testing.T, do *http.Response) []byte {
defer func() { _ = do.Body.Close() }()
body, err := ioutil.ReadAll(do.Body)
require.NoError(t, err)
return body
}

func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) {
var pool *x509.CertPool
if useSystemCA {
var err error
if pool, err = x509.SystemCertPool(); err != nil {
return nil, err
}
} else {
pool = x509.NewCertPool()
}

if len(certPath) > 0 {
caFile, err := ioutil.ReadFile(certPath)
if err != nil {
return nil, err
}
if !pool.AppendCertsFromPEM(caFile) {
return nil, errors.Errorf("error reading CA file %q", certPath)
}
}

return pool, nil
}
37 changes: 37 additions & 0 deletions tlstest/zero_https/all_routes_tls/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
version: "3.5"
services:
alpha1:
image: dgraph/dgraph:latest
container_name: alpha1
working_dir: /data/alpha1
labels:
cluster: test
ports:
- 8180:8180
- 9180:9180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
command: /gobin/dgraph alpha -o 100 --my=alpha1:7180 --zero=zero1:5180 --logtostderr -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16
zero1:
image: dgraph/dgraph:latest
container_name: zero1
working_dir: /data/zero1
labels:
cluster: test
ports:
- 5180:5180
- 6180:6180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
- type: bind
source: ../../tls
target: /dgraph-tls
read_only: true
command: /gobin/dgraph zero -o 100 --idx=1 --my=zero1:5180 --tls_dir /dgraph-tls -v=2 --bindall
volumes: {}
Loading

0 comments on commit 5482c60

Please sign in to comment.