Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dgraph): enabling TLS config in http zero #6691

Merged
merged 12 commits into from
Oct 22, 2020
90 changes: 78 additions & 12 deletions dgraph/cmd/zero/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
package zero

import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"time"

"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/x"
"github.com/gogo/protobuf/jsonpb"
"github.com/golang/glog"
"github.com/soheilhy/cmux"
)

// intFromQueryParam checks for name as a query param, converts it to uint64 and returns it.
Expand Down Expand Up @@ -239,23 +244,84 @@ func (st *state) pingResponse(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("OK"))
}

func (st *state) serveHTTP(l net.Listener) {
srv := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 2 * time.Minute,
func (st *state) startListenHttpAndHttps(l net.Listener) {
if Zero.Conf.GetString("tls_dir") == "" && Zero.Conf.GetString("tls_disabled_route") != "" {
glog.Fatal("--tls_disabled_route is provided as an option but tls_dir is empty. Please provide --tls_dir")
}

m := cmux.New(l)
startServers(m)

go func() {
defer st.zero.closer.Done()
err := srv.Serve(l)
glog.Errorf("Stopped taking more http(s) requests. Err: %v", err)
ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
glog.Infoln("All http(s) requests finished.")
err := m.Serve()
if err != nil {
glog.Errorf("Http(s) shutdown err: %v", err)
glog.Errorf("error from cmux serve: %v", err)
}
}()
}

func startServers(m cmux.CMux) {
httpRule := m.Match(func(r io.Reader) bool {
//no tls config is provided. http is being used.
if opts.tlsDir == "" {
return true
}
//tls config is provided but none of the routes are disabled.
if len(opts.tlsDisabledRoutes) == 0 {
return false
}
path, ok := parseRequestPath(r)
// not able to parse the request. Let it be resolved via TLS
if !ok {
return false
}
for _, r := range opts.tlsDisabledRoutes {
if strings.HasPrefix(path, r) {
return true
}
}
return false
})
go startListen(httpRule)

// if tls is enabled, make tls encryption based connections as default
if Zero.Conf.GetString("tls_dir") != "" {
tlsCfg, err := x.LoadServerTLSConfig(Zero.Conf, "node.crt", "node.key")
x.Check(err)
if tlsCfg == nil {
glog.Fatalf("tls_dir is set but tls config provided is not correct. Please define correct variable --tls_dir")
}

httpsRule := m.Match(cmux.Any())
//this is chained listener. tls listener will decrypt the message and send it in plain text to HTTP server
go startListen(tls.NewListener(httpsRule, tlsCfg))
}
}

func startListen(l net.Listener) {
srv := &http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 600 * time.Second,
IdleTimeout: 2 * time.Minute,
}

err := srv.Serve(l)
glog.Errorf("Stopped taking more http(s) requests. Err: %v", err)
ctx, cancel := context.WithTimeout(context.Background(), 630*time.Second)
defer cancel()
err = srv.Shutdown(ctx)
glog.Infoln("All http(s) requests finished.")
if err != nil {
glog.Errorf("Http(s) shutdown err: %v", err)
}
}

func parseRequestPath(r io.Reader) (path string, ok bool) {
request, err := http.ReadRequest(bufio.NewReader(r))
if err != nil {
return "", false
}

return request.URL.Path, true
}
20 changes: 17 additions & 3 deletions dgraph/cmd/zero/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"os"
"os/signal"
"strings"
"syscall"
"time"

Expand Down Expand Up @@ -51,8 +52,9 @@ type options struct {
peer string
w string
rebalanceInterval time.Duration

totalCache int64
tlsDir string
tlsDisabledRoutes []string
totalCache int64
}

var opts options
Expand Down Expand Up @@ -88,6 +90,12 @@ instances to achieve high-availability.
flag.StringP("wal", "w", "zw", "Directory storing WAL.")
flag.Duration("rebalance_interval", 8*time.Minute, "Interval for trying a predicate move.")
flag.String("enterprise_license", "", "Path to the enterprise license file.")
// TLS configurations
flag.String("tls_dir", "", "Path to directory that has TLS certificates and keys.")
flag.Bool("tls_use_system_ca", true, "Include System CA into CA Certs.")
flag.String("tls_client_auth", "VERIFYIFGIVEN", "Enable TLS client authentication")
flag.String("tls_disabled_route", "", "comma separated zero endpoint which will be disabled from TLS encryption."+
"Valid values are /health,/state,/removeNode,/moveTablet,/assign,/enterpriseLicense,/debug.")
}

func setupListener(addr string, port int, kind string) (listener net.Listener, err error) {
Expand Down Expand Up @@ -160,6 +168,10 @@ func run() {
}

x.PrintVersion()
var tlsDisRoutes []string
if Zero.Conf.GetString("tls_disabled_route") != "" {
tlsDisRoutes = strings.Split(Zero.Conf.GetString("tls_disabled_route"), ",")
}

opts = options{
bindall: Zero.Conf.GetBool("bindall"),
Expand All @@ -170,6 +182,8 @@ func run() {
w: Zero.Conf.GetString("wal"),
rebalanceInterval: Zero.Conf.GetDuration("rebalance_interval"),
totalCache: int64(Zero.Conf.GetInt("cache_mb")),
tlsDir: Zero.Conf.GetString("tls_dir"),
tlsDisabledRoutes: tlsDisRoutes,
}
glog.Infof("Setting Config to: %+v", opts)

Expand Down Expand Up @@ -227,7 +241,7 @@ func run() {
// Initialize the servers.
var st state
st.serveGRPC(grpcListener, store)
st.serveHTTP(httpListener)
st.startListenHttpAndHttps(httpListener)

http.HandleFunc("/health", st.pingResponse)
http.HandleFunc("/state", st.getState)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ require (
github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829
github.com/prometheus/common v0.4.1 // indirect
github.com/prometheus/procfs v0.0.0-20190517135640-51af30a78b0e // indirect
github.com/soheilhy/cmux v0.1.4
github.com/spf13/cast v1.3.0
github.com/spf13/cobra v0.0.5
github.com/spf13/pflag v1.0.3
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1
github.com/smartystreets/goconvey v0.0.0-20190330032615-68dc04aab96a/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
Expand Down
133 changes: 133 additions & 0 deletions tlstest/zero_https/all_routes_tls/all_routes_tls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package all_routes_tls

import (
"crypto/tls"
"crypto/x509"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
"io/ioutil"
"net/http"
"strings"
"testing"
"time"
)

type testCase struct {
url string
statusCode int
response string
}

var testCasesHttp = []testCase{
{
url: "http://localhost:6180/health",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
{
url: "http://localhost:6180/state",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
{
url: "http://localhost:6180/removeNode?id=2&group=0",
response: "Client sent an HTTP request to an HTTPS server.\n",
statusCode: 400,
},
}

func TestZeroWithAllRoutesTLSWithHTTPClient(t *testing.T) {
client := http.Client{
Timeout: time.Second * 10,
}
defer client.CloseIdleConnections()
for _, test := range testCasesHttp {
request, err := http.NewRequest("GET", test.url, nil)
require.NoError(t, err)
do, err := client.Do(request)
require.NoError(t, err)
if do != nil && do.StatusCode != test.statusCode {
t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode)
}

body := readResponseBody(t, do)
if test.response != string(body) {
t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response)
}
}
}

var testCasesHttps = []testCase{
{
url: "https://localhost:6180/health",
response: "OK",
statusCode: 200,
},
{
url: "https://localhost:6180/state",
response: "\"id\":\"1\",\"groupId\":0,\"addr\":\"zero1:5180\",\"leader\":true,\"amDead\":false",
statusCode: 200,
},
}

func TestZeroWithAllRoutesTLSWithTLSClient(t *testing.T) {
pool, err := generateCertPool("../../tls/ca.crt", true)
require.NoError(t, err)

tlsCfg := &tls.Config{RootCAs: pool, ServerName: "localhost", InsecureSkipVerify: true}
tr := &http.Transport{
IdleConnTimeout: 30 * time.Second,
DisableCompression: true,
TLSClientConfig: tlsCfg,
}
client := http.Client{
Transport: tr,
}

defer client.CloseIdleConnections()
for _, test := range testCasesHttps {
request, err := http.NewRequest("GET", test.url, nil)
require.NoError(t, err)
do, err := client.Do(request)
require.NoError(t, err)
if do != nil && do.StatusCode != test.statusCode {
t.Fatalf("status code is not same. Got: %d Expected: %d", do.StatusCode, test.statusCode)
}

body := readResponseBody(t, do)
if !strings.Contains(string(body), test.response) {
t.Fatalf("response is not same. Got: %s Expected: %s", string(body), test.response)
}
}
}

func readResponseBody(t *testing.T, do *http.Response) []byte {
defer func() { _ = do.Body.Close() }()
body, err := ioutil.ReadAll(do.Body)
require.NoError(t, err)
return body
}

func generateCertPool(certPath string, useSystemCA bool) (*x509.CertPool, error) {
var pool *x509.CertPool
if useSystemCA {
var err error
if pool, err = x509.SystemCertPool(); err != nil {
return nil, err
}
} else {
pool = x509.NewCertPool()
}

if len(certPath) > 0 {
caFile, err := ioutil.ReadFile(certPath)
if err != nil {
return nil, err
}
if !pool.AppendCertsFromPEM(caFile) {
return nil, errors.Errorf("error reading CA file %q", certPath)
}
}

return pool, nil
}
37 changes: 37 additions & 0 deletions tlstest/zero_https/all_routes_tls/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
version: "3.5"
services:
alpha1:
image: dgraph/dgraph:latest
container_name: alpha1
working_dir: /data/alpha1
labels:
cluster: test
ports:
- 8180:8180
- 9180:9180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
command: /gobin/dgraph alpha -o 100 --my=alpha1:7180 --zero=zero1:5180 --logtostderr -v=2 --whitelist=10.0.0.0/8,172.16.0.0/12,192.168.0.0/16
zero1:
image: dgraph/dgraph:latest
container_name: zero1
working_dir: /data/zero1
labels:
cluster: test
ports:
- 5180:5180
- 6180:6180
volumes:
- type: bind
source: $GOPATH/bin
target: /gobin
read_only: true
- type: bind
source: ../../tls
target: /dgraph-tls
read_only: true
command: /gobin/dgraph zero -o 100 --idx=1 --my=zero1:5180 --tls_dir /dgraph-tls -v=2 --bindall
volumes: {}
Loading