diff --git a/go.mod b/go.mod index de71f5dac4b..513d0f634fc 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/go-tpm v0.9.0 github.com/klauspost/compress v1.17.8 github.com/minio/highwayhash v1.0.2 - github.com/nats-io/jwt/v2 v2.5.6 + github.com/nats-io/jwt/v2 v2.5.7 github.com/nats-io/nats.go v1.34.1 github.com/nats-io/nkeys v0.4.7 github.com/nats-io/nuid v1.0.1 diff --git a/go.sum b/go.sum index 3c90966508f..05ada266284 100644 --- a/go.sum +++ b/go.sum @@ -5,8 +5,8 @@ github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0N github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/nats-io/jwt/v2 v2.5.6 h1:Cp618+z4q042sWqHiSoIHFT08OZtAskui0hTmRfmGGQ= -github.com/nats-io/jwt/v2 v2.5.6/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= +github.com/nats-io/jwt/v2 v2.5.7 h1:j5lH1fUXCnJnY8SsQeB/a/z9Azgu2bYIDvtPVNdxe2c= +github.com/nats-io/jwt/v2 v2.5.7/go.mod h1:ZdWS1nZa6WMZfFwwgpEaqBV8EPGVgOTDHN/wTbz0Y5A= github.com/nats-io/nats.go v1.34.1 h1:syWey5xaNHZgicYBemv0nohUPPmaLteiBEUT6Q5+F/4= github.com/nats-io/nats.go v1.34.1/go.mod h1:Ubdu4Nh9exXdSz0RVWRFBbRfrbSxOYd26oF0wkWclB8= github.com/nats-io/nkeys v0.4.7 h1:RwNJbbIdYCoClSDNY7QVKZlyb/wfT6ugvFCiKy6vDvI= diff --git a/server/auth.go b/server/auth.go index 97106343450..700e5741442 100644 --- a/server/auth.go +++ b/server/auth.go @@ -1464,7 +1464,8 @@ func validateAllowedConnectionTypes(m map[string]struct{}) error { switch ctuc { case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS, - jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS: + jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS, + jwt.ConnectionTypeInProcess: default: return fmt.Errorf("unknown connection type %q", ct) } diff --git a/server/client.go b/server/client.go index 2429cc5eb04..0ee3b7508c7 100644 --- a/server/client.go +++ b/server/client.go @@ -280,6 +280,7 @@ type client struct { trace bool echo bool noIcb bool + iproc bool // In-Process connection, set at creation and immutable. tags jwt.TagList nameTag string @@ -5971,7 +5972,8 @@ func convertAllowedConnectionTypes(cts []string) (map[string]struct{}, error) { switch i { case jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket, jwt.ConnectionTypeLeafnode, jwt.ConnectionTypeLeafnodeWS, - jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS: + jwt.ConnectionTypeMqtt, jwt.ConnectionTypeMqttWS, + jwt.ConnectionTypeInProcess: m[i] = struct{}{} default: unknown = append(unknown, i) @@ -5998,7 +6000,11 @@ func (c *client) connectionTypeAllowed(acts map[string]struct{}) bool { case CLIENT: switch c.clientType() { case NATS: - want = jwt.ConnectionTypeStandard + if c.iproc { + want = jwt.ConnectionTypeInProcess + } else { + want = jwt.ConnectionTypeStandard + } case WS: want = jwt.ConnectionTypeWebsocket case MQTT: diff --git a/server/client_test.go b/server/client_test.go index ed42b9568cc..a88fd5e2a0c 100644 --- a/server/client_test.go +++ b/server/client_test.go @@ -2962,3 +2962,96 @@ func TestRemoveHeaderIfPrefixPresent(t *testing.T) { t.Fatalf("Expected headers to be stripped, got %q", hdr) } } + +func TestInProcessAllowedConnectionType(t *testing.T) { + tmpl := ` + listen: "127.0.0.1:-1" + accounts { + A { users: [{user: "test", password: "pwd", allowed_connection_types: ["%s"]}] } + } + write_deadline: "500ms" + ` + for _, test := range []struct { + name string + ct string + inProcessOnly bool + }{ + {"conf inprocess", jwt.ConnectionTypeInProcess, true}, + {"conf standard", jwt.ConnectionTypeStandard, false}, + } { + t.Run(test.name, func(t *testing.T) { + conf := createConfFile(t, []byte(fmt.Sprintf(tmpl, test.ct))) + s, _ := RunServerWithConfig(conf) + defer s.Shutdown() + + // Create standard connection + nc, err := nats.Connect(s.ClientURL(), nats.UserInfo("test", "pwd")) + if test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected standard connection to fail, it did not") + } + // Works if nc is nil (which it will if only in-process are allowed) + nc.Close() + + // Create inProcess connection + nc, err = nats.Connect(_EMPTY_, nats.UserInfo("test", "pwd"), nats.InProcessServer(s)) + if !test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected in-process connection to fail, it did not") + } + // Works if nc is nil (which it will if only standard are allowed) + nc.Close() + }) + } + for _, test := range []struct { + name string + ct string + inProcessOnly bool + }{ + {"jwt inprocess", jwt.ConnectionTypeInProcess, true}, + {"jwt standard", jwt.ConnectionTypeStandard, false}, + } { + t.Run(test.name, func(t *testing.T) { + skp, _ := nkeys.FromSeed(oSeed) + spub, _ := skp.PublicKey() + + o := defaultServerOptions + o.TrustedKeys = []string{spub} + o.WriteDeadline = 500 * time.Millisecond + s := RunServer(&o) + defer s.Shutdown() + + buildMemAccResolver(s) + + kp, _ := nkeys.CreateAccount() + aPub, _ := kp.PublicKey() + claim := jwt.NewAccountClaims(aPub) + aJwt, err := claim.Encode(oKp) + require_NoError(t, err) + + addAccountToMemResolver(s, aPub, aJwt) + + creds := createUserWithLimit(t, kp, time.Time{}, + func(j *jwt.UserPermissionLimits) { + j.AllowedConnectionTypes.Add(test.ct) + }) + // Create standard connection + nc, err := nats.Connect(s.ClientURL(), nats.UserCredentials(creds)) + if test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected standard connection to fail, it did not") + } + // Works if nc is nil (which it will if only in-process are allowed) + nc.Close() + + // Create inProcess connection + nc, err = nats.Connect(_EMPTY_, nats.UserCredentials(creds), nats.InProcessServer(s)) + if !test.inProcessOnly && err == nil { + nc.Close() + t.Fatal("Expected in-process connection to fail, it did not") + } + // Works if nc is nil (which it will if only standard are allowed) + nc.Close() + }) + } +} diff --git a/server/server.go b/server/server.go index e6f81d5fed2..b0d5e778c4f 100644 --- a/server/server.go +++ b/server/server.go @@ -3078,7 +3078,16 @@ func (s *Server) createClientEx(conn net.Conn, inProcess bool) *client { } now := time.Now() - c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now} + c := &client{ + srv: s, + nc: conn, + opts: defaultOpts, + mpay: maxPay, + msubs: maxSubs, + start: now, + last: now, + iproc: inProcess, + } c.registerWithAccount(s.globalAccount())