diff --git a/cmd/clef/main.go b/cmd/clef/main.go index 5b4a5bf28..00d135751 100644 --- a/cmd/clef/main.go +++ b/cmd/clef/main.go @@ -600,7 +600,7 @@ func signer(c *cli.Context) error { if err != nil { utils.Fatalf("Could not register API: %w", err) } - handler := node.NewHTTPHandlerStack(srv, cors, vhosts) + handler := node.NewHTTPHandlerStack(srv, cors, vhosts, nil) // start http server httpEndpoint := fmt.Sprintf("%s:%d", c.GlobalString(utils.HTTPListenAddrFlag.Name), c.Int(rpcPortFlag.Name)) diff --git a/cmd/gocore/main.go b/cmd/gocore/main.go index 5443734d2..d040cfb83 100644 --- a/cmd/gocore/main.go +++ b/cmd/gocore/main.go @@ -159,6 +159,8 @@ var ( utils.HTTPListenAddrFlag, utils.HTTPPortFlag, utils.HTTPCORSDomainFlag, + utils.AuthPortFlag, + utils.JWTSecretFlag, utils.HTTPVirtualHostsFlag, utils.LegacyRPCEnabledFlag, utils.LegacyRPCListenAddrFlag, diff --git a/cmd/gocore/retestxcb.go b/cmd/gocore/retestxcb.go index bfe56b616..c52305ef1 100644 --- a/cmd/gocore/retestxcb.go +++ b/cmd/gocore/retestxcb.go @@ -807,7 +807,7 @@ func retestxcb(ctx *cli.Context) error { if err != nil { utils.Fatalf("Could not register RPC apis: %w", err) } - handler := node.NewHTTPHandlerStack(srv, cors, vhosts) + handler := node.NewHTTPHandlerStack(srv, cors, vhosts, nil) // start http server var RetestethHTTPTimeouts = rpc.HTTPTimeouts{ diff --git a/cmd/gocore/usage.go b/cmd/gocore/usage.go index 0c543a295..0e04f913e 100644 --- a/cmd/gocore/usage.go +++ b/cmd/gocore/usage.go @@ -149,6 +149,7 @@ var AppHelpFlagGroups = []flagGroup{ Flags: []cli.Flag{ utils.IPCDisabledFlag, utils.IPCPathFlag, + utils.JWTSecretFlag, utils.HTTPEnabledFlag, utils.HTTPListenAddrFlag, utils.HTTPPortFlag, diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 880e48998..7ddcbba52 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -462,6 +462,16 @@ var ( Usage: "Sets a cap on transaction fee (in core) that can be sent via the RPC APIs (0 = no cap)", Value: xcb.DefaultConfig.RPCTxFeeCap, } + // Authenticated port settings + AuthPortFlag = cli.IntFlag{ + Name: "authrpc.port", + Usage: "Listening port for authenticated APIs", + Value: node.DefaultAuthPort, + } + JWTSecretFlag = cli.StringFlag{ + Name: "authrpc.jwtsecret", + Usage: "JWT secret (or path to a jwt secret) to use for authenticated RPC endpoints", + } // Logging and debug settings XcbStatsURLFlag = cli.StringFlag{ Name: "xcbstats", @@ -892,6 +902,11 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { cfg.HTTPCors = splitAndTrim(ctx.GlobalString(LegacyRPCCORSDomainFlag.Name)) log.Warn("The flag --rpccorsdomain is deprecated and will be removed in the future, please use --http.corsdomain") } + + if ctx.GlobalIsSet(AuthPortFlag.Name) { + cfg.AuthPort = ctx.GlobalInt(AuthPortFlag.Name) + } + if ctx.GlobalIsSet(HTTPCORSDomainFlag.Name) { cfg.HTTPCors = splitAndTrim(ctx.GlobalString(HTTPCORSDomainFlag.Name)) } @@ -1191,6 +1206,10 @@ func SetNodeConfig(ctx *cli.Context, cfg *node.Config) { setDataDir(ctx, cfg) setSmartCard(ctx, cfg) + if ctx.GlobalIsSet(JWTSecretFlag.Name) { + cfg.JWTSecret = ctx.GlobalString(JWTSecretFlag.Name) + } + if ctx.GlobalIsSet(ExternalSignerFlag.Name) { cfg.ExternalSigner = ctx.GlobalString(ExternalSignerFlag.Name) } diff --git a/go.mod b/go.mod index f50af3f90..be1283a7a 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/gballet/go-libpcsclite v0.0.0-20190607065134-2772fd86a8ff github.com/go-sourcemap/sourcemap v2.1.3+incompatible // indirect github.com/go-stack/stack v1.8.0 + github.com/golang-jwt/jwt/v4 v4.3.0 github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.3 github.com/gorilla/websocket v1.4.2 @@ -57,7 +58,7 @@ require ( golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d + golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f golang.org/x/text v0.3.7 golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba google.golang.org/protobuf v1.28.0 // indirect diff --git a/go.sum b/go.sum index 2891ecd29..05f057370 100644 --- a/go.sum +++ b/go.sum @@ -127,8 +127,9 @@ github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/me github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= -github.com/golang-jwt/jwt/v4 v4.0.0 h1:RAqyYixv1p7uEnocuy8P1nru5wprCh/MH2BIlW5z5/o= github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.3.0 h1:kHL1vqdqWNfATmA0FNMdmZNMyZI1U6O31X4rlIPoBog= +github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -438,8 +439,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210324051608-47abb6519492/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d h1:/m5NbqQelATgoSPVC2Z23sR4kVNokFwDDyWh/3rGY+I= -golang.org/x/sys v0.0.0-20220708085239-5a0f0661e09d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f h1:v4INt8xihDGvnrfjMDVXGxw9wrfxYyCjk0KbXjhR55s= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/graphql/service.go b/graphql/service.go index 113c172ab..24d043db6 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -42,7 +42,7 @@ func newHandler(stack *node.Node, backend xcbapi.Backend, cors, vhosts []string) return err } h := &relay.Handler{Schema: s} - handler := node.NewHTTPHandlerStack(h, cors, vhosts) + handler := node.NewHTTPHandlerStack(h, cors, vhosts, nil) stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{}) stack.RegisterHandler("GraphQL", "/graphql", handler) diff --git a/node/api.go b/node/api.go index 4b5c9a8f3..25db18383 100644 --- a/node/api.go +++ b/node/api.go @@ -259,11 +259,12 @@ func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str } // Enable WebSocket on the server. - server := api.node.wsServerForPort(*port) + server := api.node.wsServerForPort(*port, false) if err := server.setListenAddr(*host, *port); err != nil { return false, err } - if err := server.enableWS(api.node.rpcAPIs, config); err != nil { + openApis, _ := api.node.GetAPIs() + if err := server.enableWS(openApis, config); err != nil { return false, err } if err := server.start(); err != nil { diff --git a/node/config.go b/node/config.go index 69ef08c6e..47ca16bbd 100644 --- a/node/config.go +++ b/node/config.go @@ -42,6 +42,7 @@ import ( const ( datadirPrivateKey = "nodekey" // Path within the datadir to the node's private key + datadirJWTKey = "jwtsecret" // Path within the datadir to the node's jwt secret datadirDefaultKeyStore = "keystore" // Path within the datadir to the keystore datadirStaticNodes = "static-nodes.json" // Path within the datadir to the static node list datadirTrustedNodes = "trusted-nodes.json" // Path within the datadir to the trusted node list @@ -120,6 +121,9 @@ type Config struct { // for ephemeral nodes). HTTPPort int `toml:",omitempty"` + // Authport is the port number on which the authenticated API is provided. + AuthPort int `toml:",omitempty"` + // HTTPCors is the Cross-Origin Resource Sharing header to send to requesting // clients. Please be aware that CORS is a browser enforced security, it's fully // useless for custom HTTP clients. @@ -198,6 +202,9 @@ type Config struct { staticNodesWarning bool trustedNodesWarning bool oldGocoreResourceWarning bool + + // JWTSecret is the hex-encoded jwt secret. + JWTSecret string `toml:",omitempty"` } // IPCEndpoint resolves an IPC endpoint based on a configured value, taking into @@ -256,7 +263,7 @@ func (c *Config) HTTPEndpoint() string { // DefaultHTTPEndpoint returns the HTTP endpoint used by default. func DefaultHTTPEndpoint() string { - config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort} + config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort, AuthPort: DefaultAuthPort} return config.HTTPEndpoint() } diff --git a/node/defaults.go b/node/defaults.go index ee6fa33d0..0abced47d 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -34,12 +34,23 @@ const ( DefaultWSPort = 8546 // Default TCP port for the websocket RPC server DefaultGraphQLHost = "localhost" // Default host interface for the GraphQL server DefaultGraphQLPort = 8547 // Default TCP port for the GraphQL server + DefaultAuthHost = "localhost" // Default host interface for the authenticated apis + DefaultAuthPort = 8551 // Default port for the authenticated apis +) + +var ( + DefaultAuthCors = []string{"localhost"} // Default cors domain for the authenticated apis + DefaultAuthVhosts = []string{"localhost"} // Default virtual hosts for the authenticated apis + DefaultAuthOrigins = []string{"localhost"} // Default origins for the authenticated apis + DefaultAuthPrefix = "" // Default prefix for the authenticated apis + DefaultAuthModules = []string{"xcb"} ) // DefaultConfig contains reasonable default settings. var DefaultConfig = Config{ DataDir: DefaultDataDir(), HTTPPort: DefaultHTTPPort, + AuthPort: DefaultAuthPort, HTTPModules: []string{"net", "web3"}, HTTPVirtualHosts: []string{"localhost"}, HTTPTimeouts: rpc.DefaultHTTPTimeouts, diff --git a/node/endpoints.go b/node/endpoints.go index 179e41fa2..a412db4b0 100644 --- a/node/endpoints.go +++ b/node/endpoints.go @@ -59,8 +59,10 @@ func checkModuleAvailability(modules []string, apis []rpc.API) (bad, available [ } } for _, name := range modules { - if _, ok := availableSet[name]; !ok && name != rpc.MetadataApi { - bad = append(bad, name) + if _, ok := availableSet[name]; !ok { + if name != rpc.MetadataApi { + bad = append(bad, name) + } } } return bad, available diff --git a/node/jwt_handler.go b/node/jwt_handler.go new file mode 100644 index 000000000..92da1cd4e --- /dev/null +++ b/node/jwt_handler.go @@ -0,0 +1,78 @@ +// Copyright 2022 by the Authors +// This file is part of the go-core library. +// +// The go-core library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-core library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-core library. If not, see . + +package node + +import ( + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +type jwtHandler struct { + keyFunc func(token *jwt.Token) (interface{}, error) + next http.Handler +} + +// newJWTHandler creates a http.Handler with jwt authentication support. +func newJWTHandler(secret []byte, next http.Handler) http.Handler { + return &jwtHandler{ + keyFunc: func(token *jwt.Token) (interface{}, error) { + return secret, nil + }, + next: next, + } +} + +// ServeHTTP implements http.Handler +func (handler *jwtHandler) ServeHTTP(out http.ResponseWriter, r *http.Request) { + var ( + strToken string + claims jwt.RegisteredClaims + ) + if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { + strToken = strings.TrimPrefix(auth, "Bearer ") + } + if len(strToken) == 0 { + http.Error(out, "missing token", http.StatusForbidden) + return + } + // We explicitly set only HS256 allowed, and also disables the + // claim-check: the RegisteredClaims internally requires 'iat' to + // be no later than 'now', but we allow for a bit of drift. + token, err := jwt.ParseWithClaims(strToken, &claims, handler.keyFunc, + jwt.WithValidMethods([]string{"HS256"}), + jwt.WithoutClaimsValidation()) + + switch { + case err != nil: + http.Error(out, err.Error(), http.StatusForbidden) + case !token.Valid: + http.Error(out, "invalid token", http.StatusForbidden) + case !claims.VerifyExpiresAt(time.Now(), false): // optional + http.Error(out, "token is expired", http.StatusForbidden) + case claims.IssuedAt == nil: + http.Error(out, "missing issued-at", http.StatusForbidden) + case time.Since(claims.IssuedAt.Time) > 5*time.Second: + http.Error(out, "stale token", http.StatusForbidden) + case time.Until(claims.IssuedAt.Time) > 5*time.Second: + http.Error(out, "future token", http.StatusForbidden) + default: + handler.next.ServeHTTP(out, r) + } +} diff --git a/node/node.go b/node/node.go index bcac144c9..f4aa4fd26 100644 --- a/node/node.go +++ b/node/node.go @@ -17,8 +17,11 @@ package node import ( + crand "crypto/rand" "errors" "fmt" + "github.com/core-coin/go-core/common" + "github.com/core-coin/go-core/common/hexutil" "github.com/core-coin/go-core/core/led" "github.com/core-coin/go-core/xcbdb" "net/http" @@ -55,6 +58,8 @@ type Node struct { rpcAPIs []rpc.API // List of APIs currently provided by the node http *httpServer // ws *httpServer // + httpAuth *httpServer // + wsAuth *httpServer // ipc *ipcServer // Stores information about the ipc http server inprocHandler *rpc.Server // In-process RPC request handler to process the API requests @@ -146,7 +151,9 @@ func New(conf *Config) (*Node, error) { // Configure RPC servers. node.http = newHTTPServer(node.log, conf.HTTPTimeouts) + node.httpAuth = newHTTPServer(node.log, conf.HTTPTimeouts) node.ws = newHTTPServer(node.log, rpc.DefaultHTTPTimeouts) + node.wsAuth = newHTTPServer(node.log, rpc.DefaultHTTPTimeouts) node.ipc = newIPCServer(node.log, conf.IPCEndpoint()) return node, nil @@ -168,7 +175,8 @@ func (n *Node) Start() error { return ErrNodeStopped } n.state = runningState - err := n.startNetworking() + // open networking and RPC endpoints + err := n.openEndpoints() lifecycles := make([]Lifecycle, len(n.lifecycles)) copy(lifecycles, n.lifecycles) n.lock.Unlock() @@ -263,12 +271,14 @@ func (n *Node) doClose(errs []error) error { } } -// startNetworking starts all network endpoints. -func (n *Node) startNetworking() error { +// openEndpoints starts all network and RPC endpoints. +func (n *Node) openEndpoints() error { + // start networking endpoints n.log.Info("Starting peer-to-peer node", "instance", n.server.Name) if err := n.server.Start(); err != nil { return convertFileLockError(err) } + // start RPC endpoints err := n.startRPC() if err != nil { n.stopRPC() @@ -337,7 +347,50 @@ func (n *Node) closeDataDir() { } } -// configureRPC is a helper method to configure all the various RPC endpoints during node +// obtainJWTSecret loads the jwt-secret, either from the provided config, +// or from the default location. If neither of those are present, it generates +// a new secret and stores to the default location. +func (n *Node) obtainJWTSecret(cliParam string) ([]byte, error) { + var fileName string + if len(cliParam) > 0 { + // If a plaintext secret was provided via cli flags, use that + jwtSecret := common.FromHex(cliParam) + if len(jwtSecret) == 32 && strings.HasPrefix(cliParam, "0x") { + log.Warn("Plaintext JWT secret provided, please consider passing via file") + return jwtSecret, nil + } + // path provided + fileName = cliParam + } else { + // no path provided, use default + fileName = n.ResolvePath(datadirJWTKey) + } + // try reading from file + log.Debug("Reading JWT secret", "path", fileName) + if data, err := os.ReadFile(fileName); err == nil { + jwtSecret := common.FromHex(strings.TrimSpace(string(data))) + if len(jwtSecret) == 32 { + return jwtSecret, nil + } + log.Error("Invalid JWT secret", "path", fileName, "length", len(jwtSecret)) + return nil, errors.New("invalid JWT secret") + } + // Need to generate one + jwtSecret := make([]byte, 32) + crand.Read(jwtSecret) + // if we're in --dev mode, don't bother saving, just show it + if fileName == "" { + log.Info("Generated ephemeral JWT secret", "secret", hexutil.Encode(jwtSecret)) + return jwtSecret, nil + } + if err := os.WriteFile(fileName, []byte(hexutil.Encode(jwtSecret)), 0600); err != nil { + return nil, err + } + log.Info("Generated JWT secret", "path", fileName) + return jwtSecret, nil +} + +// startRPC is a helper method to configure all the various RPC endpoints during node // startup. It's not meant to be called at any time afterwards as it makes certain // assumptions about the state of the node. func (n *Node) startRPC() error { @@ -352,54 +405,124 @@ func (n *Node) startRPC() error { } } - // Configure HTTP. - if n.config.HTTPHost != "" { - config := httpConfig{ + var ( + servers []*httpServer + open, all = n.GetAPIs() + ) + + initHttp := func(server *httpServer, apis []rpc.API, port int) error { + if err := server.setListenAddr(n.config.HTTPHost, port); err != nil { + return err + } + if err := server.enableRPC(apis, httpConfig{ CorsAllowedOrigins: n.config.HTTPCors, Vhosts: n.config.HTTPVirtualHosts, Modules: n.config.HTTPModules, prefix: n.config.HTTPPathPrefix, + }); err != nil { + return err } - if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil { + servers = append(servers, server) + return nil + } + initWS := func(apis []rpc.API, port int) error { + server := n.wsServerForPort(port, false) + if err := server.setListenAddr(n.config.WSHost, port); err != nil { return err } - if err := n.http.enableRPC(n.rpcAPIs, config); err != nil { + if err := server.enableWS(n.rpcAPIs, wsConfig{ + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + prefix: n.config.WSPathPrefix, + }); err != nil { + return err + } + servers = append(servers, server) + return nil + } + + initAuth := func(apis []rpc.API, port int, secret []byte) error { + // Enable auth via HTTP + server := n.httpAuth + if err := server.setListenAddr(DefaultAuthHost, port); err != nil { + return err + } + if err := server.enableRPC(apis, httpConfig{ + CorsAllowedOrigins: DefaultAuthCors, + Vhosts: DefaultAuthVhosts, + Modules: DefaultAuthModules, + prefix: DefaultAuthPrefix, + jwtSecret: secret, + }); err != nil { + return err + } + servers = append(servers, server) + // Enable auth via WS + server = n.wsServerForPort(port, true) + if err := server.setListenAddr(DefaultAuthHost, port); err != nil { + return err + } + if err := server.enableWS(apis, wsConfig{ + Modules: DefaultAuthModules, + Origins: DefaultAuthOrigins, + prefix: DefaultAuthPrefix, + jwtSecret: secret, + }); err != nil { + return err + } + servers = append(servers, server) + return nil + } + // Set up HTTP. + if n.config.HTTPHost != "" { + // Configure legacy unauthenticated HTTP. + if err := initHttp(n.http, open, n.config.HTTPPort); err != nil { return err } } // Configure WebSocket. if n.config.WSHost != "" { - server := n.wsServerForPort(n.config.WSPort) - config := wsConfig{ - Modules: n.config.WSModules, - Origins: n.config.WSOrigins, - prefix: n.config.WSPathPrefix, + // legacy unauthenticated + if err := initWS(open, n.config.WSPort); err != nil { + return err } - if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil { + } + // Configure authenticated API + if len(open) != len(all) { + jwtSecret, err := n.obtainJWTSecret(n.config.JWTSecret) + if err != nil { return err } - if err := server.enableWS(n.rpcAPIs, config); err != nil { + if err := initAuth(all, n.config.AuthPort, jwtSecret); err != nil { return err } } - - if err := n.http.start(); err != nil { - return err + // Start the servers + for _, server := range servers { + if err := server.start(); err != nil { + return err + } } - return n.ws.start() + return nil } -func (n *Node) wsServerForPort(port int) *httpServer { - if n.config.HTTPHost == "" || n.http.port == port { - return n.http +func (n *Node) wsServerForPort(port int, authenticated bool) *httpServer { + httpServer, wsServer := n.http, n.ws + if authenticated { + httpServer, wsServer = n.httpAuth, n.wsAuth + } + if n.config.HTTPHost == "" || httpServer.port == port { + return httpServer } - return n.ws + return wsServer } func (n *Node) stopRPC() { n.http.stop() n.ws.stop() + n.httpAuth.stop() + n.wsAuth.stop() n.ipc.stop() n.stopInProc() } @@ -460,6 +583,17 @@ func (n *Node) RegisterAPIs(apis []rpc.API) { n.rpcAPIs = append(n.rpcAPIs, apis...) } +// GetAPIs return two sets of APIs, both the ones that do not require +// authentication, and the complete set +func (n *Node) GetAPIs() (unauthenticated, all []rpc.API) { + for _, api := range n.rpcAPIs { + if !api.Authenticated { + unauthenticated = append(unauthenticated, api) + } + } + return unauthenticated, n.rpcAPIs +} + // RegisterHandler mounts a handler on the given path on the canonical HTTP server. // // The name of the handler is shown in a log message when the HTTP server starts @@ -471,6 +605,7 @@ func (n *Node) RegisterHandler(name, path string, handler http.Handler) { if n.state != initializingState { panic("can't register HTTP handler on running/stopped node") } + n.http.mux.Handle(path, handler) n.http.handlerNames[path] = name } diff --git a/node/node_test.go b/node/node_test.go index 554050506..0b90a1264 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -393,7 +393,7 @@ func TestLifecycleTerminationGuarantee(t *testing.T) { // on the given prefix func TestRegisterHandler_Successful(t *testing.T) { node := createNode(t, 7878, 7979) - + defer node.Close() // create and mount handler handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("success")) @@ -483,7 +483,6 @@ func TestWebsocketHTTPOnSeparatePort_WSRequest(t *testing.T) { if !checkRPC(node.HTTPEndpoint()) { t.Fatalf("http request failed") } - } type rpcPrefixTest struct { @@ -578,13 +577,13 @@ func (test rpcPrefixTest) check(t *testing.T, node *Node) { } } for _, path := range test.wantWS { - err := wsRequest(t, wsBase+path, "") + err := wsRequest(t, wsBase+path) if err != nil { t.Errorf("Error: %s: WebSocket connection failed: %v", path, err) } } for _, path := range test.wantNoWS { - err := wsRequest(t, wsBase+path, "") + err := wsRequest(t, wsBase+path) if err == nil { t.Errorf("Error: %s: WebSocket connection succeeded for path in wantNoWS", path) } diff --git a/node/rpcstack.go b/node/rpcstack.go index 41d36951c..5ea8e4648 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -39,13 +39,15 @@ type httpConfig struct { CorsAllowedOrigins []string Vhosts []string prefix string // path prefix on which to mount http handler + jwtSecret []byte // optional JWT secret } // wsConfig is the JSON-RPC/Websocket configuration type wsConfig struct { - Origins []string - Modules []string - prefix string // path prefix on which to mount ws handler + Origins []string + Modules []string + prefix string // path prefix on which to mount ws handler + jwtSecret []byte // optional JWT secret } type rpcHandler struct { @@ -152,7 +154,7 @@ func (h *httpServer) start() error { } // Log http endpoint. h.log.Info("HTTP server started", - "endpoint", listener.Addr(), + "endpoint", listener.Addr(), "auth", (h.httpConfig.jwtSecret != nil), "prefix", h.httpConfig.prefix, "cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","), "vhosts", strings.Join(h.httpConfig.Vhosts, ","), @@ -280,7 +282,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { } h.httpConfig = config h.httpHandler.Store(&rpcHandler{ - Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts), + Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts, config.jwtSecret), server: srv, }) return nil @@ -312,7 +314,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { } h.wsConfig = config h.wsHandler.Store(&rpcHandler{ - Handler: srv.WebsocketHandler(config.Origins), + Handler: NewWSHandlerStack(srv.WebsocketHandler(config.Origins), config.jwtSecret), server: srv, }) return nil @@ -353,17 +355,28 @@ func (h *httpServer) wsAllowed() bool { // isWebsocket checks the header of an http request for a websocket upgrade request. func isWebsocket(r *http.Request) bool { return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" && - strings.ToLower(r.Header.Get("Connection")) == "upgrade" + strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") } // NewHTTPHandlerStack returns wrapped http-related handlers -func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler { +func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string, jwtSecret []byte) http.Handler { // Wrap the CORS-handler within a host-handler handler := newCorsHandler(srv, cors) handler = newVHostHandler(vhosts, handler) + if len(jwtSecret) != 0 { + handler = newJWTHandler(jwtSecret, handler) + } return newGzipHandler(handler) } +// NewWSHandlerStack returns a wrapped ws-related handler. +func NewWSHandlerStack(srv http.Handler, jwtSecret []byte) http.Handler { + if len(jwtSecret) != 0 { + return newJWTHandler(jwtSecret, srv) + } + return srv +} + func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { // disable CORS support if user has not specified a custom CORS configuration if len(allowedOrigins) == 0 { diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 043cb8743..a487713f8 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -18,9 +18,11 @@ package node import ( "bytes" + "fmt" "github.com/core-coin/go-core/internal/testlog" "github.com/core-coin/go-core/log" "github.com/core-coin/go-core/rpc" + "github.com/golang-jwt/jwt/v4" "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "net/http" @@ -28,6 +30,7 @@ import ( "strconv" "strings" "testing" + "time" ) // TestCorsHandler makes sure CORS are properly handled on the http server. @@ -56,25 +59,120 @@ func TestVhosts(t *testing.T) { assert.Equal(t, resp2.StatusCode, http.StatusForbidden) } +type originTest struct { + spec string + expOk []string + expFail []string +} + +// splitAndTrim splits input separated by a comma +// and trims excessive white space from the substrings. +// Copied over from flags.go +func splitAndTrim(input string) (ret []string) { + l := strings.Split(input, ",") + for _, r := range l { + r = strings.TrimSpace(r) + if len(r) > 0 { + ret = append(ret, r) + } + } + return ret +} + // TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server. func TestWebsocketOrigins(t *testing.T) { - srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: []string{"test"}}) - defer srv.stop() + tests := []originTest{ + { + spec: "*", // allow all + expOk: []string{"", "http://test", "https://test", "http://test:8540", "https://test:8540", + "http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"}, + }, + { + spec: "test", + expOk: []string{"http://test", "https://test", "http://test:8540", "https://test:8540"}, + expFail: []string{"http://test.com", "https://foo.test", "http://testa", "http://atestb:8540", "https://atestb:8540"}, + }, + // scheme tests + { + spec: "https://test", + expOk: []string{"https://test", "https://test:9999"}, + expFail: []string{ + "test", // no scheme, required by spec + "http://test", // wrong scheme + "http://test.foo", "https://a.test.x", // subdomain variatoins + "http://testx:8540", "https://xtest:8540"}, + }, + // ip tests + { + spec: "https://12.34.56.78", + expOk: []string{"https://12.34.56.78", "https://12.34.56.78:8540"}, + expFail: []string{ + "http://12.34.56.78", // wrong scheme + "http://12.34.56.78:443", // wrong scheme + "http://1.12.34.56.78", // wrong 'domain name' + "http://12.34.56.78.a", // wrong 'domain name' + "https://87.65.43.21", "http://87.65.43.21:8540", "https://87.65.43.21:8540"}, + }, + // port tests + { + spec: "test:8540", + expOk: []string{"http://test:8540", "https://test:8540"}, + expFail: []string{ + "http://test", "https://test", // spec says port required + "http://test:8541", "https://test:8541", // wrong port + "http://bad", "https://bad", "http://bad:8540", "https://bad:8540"}, + }, + // scheme and port + { + spec: "https://test:8540", + expOk: []string{"https://test:8540"}, + expFail: []string{ + "https://test", // missing port + "http://test", // missing port, + wrong scheme + "http://test:8540", // wrong scheme + "http://test:8541", "https://test:8541", // wrong port + "http://bad", "https://bad", "http://bad:8540", "https://bad:8540"}, + }, + // several allowed origins + { + spec: "localhost,http://127.0.0.1", + expOk: []string{"localhost", "http://localhost", "https://localhost:8443", + "http://127.0.0.1", "http://127.0.0.1:8080"}, + expFail: []string{ + "https://127.0.0.1", // wrong scheme + "http://bad", "https://bad", "http://bad:8540", "https://bad:8540"}, + }, + } + for _, tc := range tests { + srv := createAndStartServer(t, &httpConfig{}, true, &wsConfig{Origins: splitAndTrim(tc.spec)}) + url := fmt.Sprintf("ws://%v", srv.listenAddr()) + for _, origin := range tc.expOk { + if err := wsRequest(t, url, "Origin", origin); err != nil { + t.Errorf("spec '%v', origin '%v': expected ok, got %v", tc.spec, origin, err) + } + } + for _, origin := range tc.expFail { + if err := wsRequest(t, url, "Origin", origin); err == nil { + t.Errorf("spec '%v', origin '%v': expected not to allow, got ok", tc.spec, origin) + } + } + srv.stop() + } +} + +// TestIsWebsocket tests if an incoming websocket upgrade request is handled properly. +func TestIsWebsocket(t *testing.T) { + r, _ := http.NewRequest("GET", "/", nil) - dialer := websocket.DefaultDialer - _, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{ - "Content-type": []string{"application/json"}, - "Sec-WebSocket-Version": []string{"13"}, - "Origin": []string{"test"}, - }) - assert.NoError(t, err) - - _, _, err = dialer.Dial("ws://"+srv.listenAddr(), http.Header{ - "Content-type": []string{"application/json"}, - "Sec-WebSocket-Version": []string{"13"}, - "Origin": []string{"bad"}, - }) - assert.Error(t, err) + assert.False(t, isWebsocket(r)) + r.Header.Set("upgrade", "websocket") + assert.False(t, isWebsocket(r)) + r.Header.Set("connection", "upgrade") + assert.True(t, isWebsocket(r)) + r.Header.Set("connection", "upgrade,keep-alive") + assert.True(t, isWebsocket(r)) + r.Header.Set("connection", " UPGRADE,keep-alive") + assert.True(t, isWebsocket(r)) } func Test_checkPath(t *testing.T) { @@ -146,13 +244,18 @@ func createAndStartServer(t *testing.T, conf *httpConfig, ws bool, wsConf *wsCon } // wsRequest attempts to open a WebSocket connection to the given URL. -func wsRequest(t *testing.T, url, browserOrigin string) error { +func wsRequest(t *testing.T, url string, extraHeaders ...string) error { t.Helper() - t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) + //t.Logf("checking WebSocket on %s (origin %q)", url, browserOrigin) headers := make(http.Header) - if browserOrigin != "" { - headers.Set("Origin", browserOrigin) + // Apply extra headers. + if len(extraHeaders)%2 != 0 { + panic("odd extraHeaders length") + } + for i := 0; i < len(extraHeaders); i += 2 { + key, value := extraHeaders[i], extraHeaders[i+1] + headers.Set(key, value) } conn, _, err := websocket.DefaultDialer.Dial(url, headers) if conn != nil { @@ -194,3 +297,79 @@ func rpcRequest(t *testing.T, url string, extraHeaders ...string) *http.Response } return resp } + +type testClaim map[string]interface{} + +func (testClaim) Valid() error { + return nil +} + +func TestJWT(t *testing.T) { + var secret = []byte("secret") + issueToken := func(secret []byte, method jwt.SigningMethod, input map[string]interface{}) string { + if method == nil { + method = jwt.SigningMethodHS256 + } + ss, _ := jwt.NewWithClaims(method, testClaim(input)).SignedString(secret) + return ss + } + expOk := []string{ + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 4})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "exp": time.Now().Unix() + 2, + })), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{ + "iat": time.Now().Unix(), + "bar": "baz", + })), + } + expFail := []string{ + // future + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() + 6})), + // stale + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix() - 6})), + // wrong algo + fmt.Sprintf("Bearer %v", issueToken(secret, jwt.SigningMethodHS512, testClaim{"iat": time.Now().Unix() + 4})), + // expired + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix(), "exp": time.Now().Unix()})), + // missing mandatory iat + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{})), + // wrong secret + fmt.Sprintf("Bearer %v", issueToken([]byte("wrong"), nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken([]byte{}, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(nil, nil, testClaim{"iat": time.Now().Unix()})), + // Various malformed syntax + fmt.Sprintf("%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("bearer %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer: %v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer:%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer\t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + fmt.Sprintf("Bearer \t%v", issueToken(secret, nil, testClaim{"iat": time.Now().Unix()})), + } + srv := createAndStartServer(t, &httpConfig{jwtSecret: []byte("secret")}, + true, &wsConfig{Origins: []string{"*"}, jwtSecret: []byte("secret")}) + wsUrl := fmt.Sprintf("ws://%v", srv.listenAddr()) + htUrl := fmt.Sprintf("http://%v", srv.listenAddr()) + + for i, token := range expOk { + if err := wsRequest(t, wsUrl, "Authorization", token); err != nil { + t.Errorf("test %d-ws, token '%v': expected ok, got %v", i, token, err) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 200 { + t.Errorf("test %d-http, token '%v': expected ok, got %v", i, token, resp.StatusCode) + } + } + for i, token := range expFail { + if err := wsRequest(t, wsUrl, "Authorization", token); err == nil { + t.Errorf("tc %d-ws, token '%v': expected not to allow, got ok", i, token) + } + if resp := rpcRequest(t, htUrl, "Authorization", token); resp.StatusCode != 403 { + t.Errorf("tc %d-http, token '%v': expected not to allow, got %v", i, token, resp.StatusCode) + } + } + srv.stop() +} diff --git a/rpc/types.go b/rpc/types.go index 91e129ffd..335e4c8fa 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -29,10 +29,11 @@ import ( // API describes the set of methods offered over the RPC interface type API struct { - Namespace string // namespace under which the rpc methods of Service are exposed - Version string // api version for DApp's - Service interface{} // receiver instance which holds the methods - Public bool // indication if the methods must be considered safe for public use + Namespace string // namespace under which the rpc methods of Service are exposed + Version string // api version for DApp's + Service interface{} // receiver instance which holds the methods + Public bool // indication if the methods must be considered safe for public use + Authenticated bool // whether the api should only be available behind authentication. } // Error wraps RPC errors, which contain an error code in addition to the message. diff --git a/rpc/websocket.go b/rpc/websocket.go index e4baba629..e8c806954 100644 --- a/rpc/websocket.go +++ b/rpc/websocket.go @@ -72,14 +72,14 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { allowAllOrigins = true } if origin != "" { - origins.Add(strings.ToLower(origin)) + origins.Add(origin) } } // allow localhost if no allowedOrigins are specified. if len(origins.ToSlice()) == 0 { origins.Add("http://localhost") if hostname, err := os.Hostname(); err == nil { - origins.Add("http://" + strings.ToLower(hostname)) + origins.Add("http://" + hostname) } } log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) @@ -94,7 +94,7 @@ func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { } // Verify origin against whitelist. origin := strings.ToLower(req.Header.Get("Origin")) - if allowAllOrigins || origins.Contains(origin) { + if allowAllOrigins || originIsAllowed(origins, origin) { return true } log.Warn("Rejected WebSocket connection", "origin", origin) @@ -117,6 +117,65 @@ func (e wsHandshakeError) Error() string { return s } +func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { + it := allowedOrigins.Iterator() + for origin := range it.C { + if ruleAllowsOrigin(origin.(string), browserOrigin) { + return true + } + } + return false +} + +func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { + var ( + allowedScheme, allowedHostname, allowedPort string + browserScheme, browserHostname, browserPort string + err error + ) + allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) + if err != nil { + log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err) + return false + } + browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) + if err != nil { + log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err) + return false + } + if allowedScheme != "" && allowedScheme != browserScheme { + return false + } + if allowedHostname != "" && allowedHostname != browserHostname { + return false + } + if allowedPort != "" && allowedPort != browserPort { + return false + } + return true +} + +func parseOriginURL(origin string) (string, string, string, error) { + parsedURL, err := url.Parse(strings.ToLower(origin)) + if err != nil { + return "", "", "", err + } + var scheme, hostname, port string + if strings.Contains(origin, "://") { + scheme = parsedURL.Scheme + hostname = parsedURL.Hostname() + port = parsedURL.Port() + } else { + scheme = "" + hostname = parsedURL.Scheme + port = parsedURL.Opaque + if hostname == "" { + hostname = origin + } + } + return scheme, hostname, port, nil +} + // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { diff --git a/rpc/websocket_test.go b/rpc/websocket_test.go index 132a21182..8aff1e4a8 100644 --- a/rpc/websocket_test.go +++ b/rpc/websocket_test.go @@ -72,7 +72,7 @@ func TestWebsocketOriginCheck(t *testing.T) { // Connections without origin header should work. client, err = DialWebsocket(context.Background(), wsURL, "") if err != nil { - t.Fatal("error for empty origin") + t.Fatalf("error for empty origin: %v", err) } client.Close() }