From 8e51d7418cffe64ea959a5b67b224176b5cad45d Mon Sep 17 00:00:00 2001 From: Aaron Lehmann Date: Wed, 21 Dec 2016 10:49:23 -0800 Subject: [PATCH] Expose needed services on the control socket For swarm mode to function without exposing a TCP port, we need services such as the dispatcher and node CA to be exposed on the control socket (i.e. a unix socket). This commit changes the manager to expose those services, and changes the raft proxy to inject some information into the context when calling the handler directly that identifies the local node. The authorization code in "ca" is updated to check for this information on the context and make use of it, instead of returning an error from RemoteNode. Also, the CA server now renewing a certificate over the control socket. Signed-off-by: Aaron Lehmann --- api/ca.pb.go | 80 +++++--- api/control.pb.go | 172 ++++++++++++++---- api/dispatcher.pb.go | 106 ++++++++--- api/health.pb.go | 34 ++-- api/logbroker.pb.go | 122 +++++++++---- api/raft.pb.go | 80 +++++--- api/resource.pb.go | 40 ++-- ca/auth.go | 18 ++ ca/server.go | 9 + manager/manager.go | 61 +++++-- protobuf/plugin/raftproxy/raftproxy.go | 97 ++++++++-- .../plugin/raftproxy/test/raftproxy_test.go | 2 +- protobuf/plugin/raftproxy/test/service.pb.go | 134 ++++++++++---- 13 files changed, 719 insertions(+), 236 deletions(-) diff --git a/api/ca.pb.go b/api/ca.pb.go index 619421b3b0..343a182e7c 100644 --- a/api/ca.pb.go +++ b/api/ca.pb.go @@ -836,12 +836,12 @@ func encodeVarintCa(data []byte, offset int, v uint64) int { } type raftProxyCAServer struct { - local CAServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local CAServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) CAServer { +func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) CAServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -858,18 +858,24 @@ func NewRaftProxyCAServer(local CAServer, connSelector raftselector.ConnProvider md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyCAServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyCAServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyCAServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -906,11 +912,15 @@ func (p *raftProxyCAServer) GetRootCACertificate(ctx context.Context, r *GetRoot conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetRootCACertificate(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -937,11 +947,15 @@ func (p *raftProxyCAServer) GetUnlockKey(ctx context.Context, r *GetUnlockKeyReq conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetUnlockKey(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -964,12 +978,12 @@ func (p *raftProxyCAServer) GetUnlockKey(ctx context.Context, r *GetUnlockKeyReq } type raftProxyNodeCAServer struct { - local NodeCAServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local NodeCAServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) NodeCAServer { +func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) NodeCAServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -986,18 +1000,24 @@ func NewRaftProxyNodeCAServer(local NodeCAServer, connSelector raftselector.Conn md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyNodeCAServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyNodeCAServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyNodeCAServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1034,11 +1054,15 @@ func (p *raftProxyNodeCAServer) IssueNodeCertificate(ctx context.Context, r *Iss conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.IssueNodeCertificate(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1065,11 +1089,15 @@ func (p *raftProxyNodeCAServer) NodeCertificateStatus(ctx context.Context, r *No conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.NodeCertificateStatus(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/control.pb.go b/api/control.pb.go index 6f36208c2c..17b4f4113d 100644 --- a/api/control.pb.go +++ b/api/control.pb.go @@ -5256,12 +5256,12 @@ func encodeVarintControl(data []byte, offset int, v uint64) int { } type raftProxyControlServer struct { - local ControlServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local ControlServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) ControlServer { +func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) ControlServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -5278,18 +5278,24 @@ func NewRaftProxyControlServer(local ControlServer, connSelector raftselector.Co md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyControlServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyControlServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyControlServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -5326,11 +5332,15 @@ func (p *raftProxyControlServer) GetNode(ctx context.Context, r *GetNodeRequest) conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetNode(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5357,11 +5367,15 @@ func (p *raftProxyControlServer) ListNodes(ctx context.Context, r *ListNodesRequ conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListNodes(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5388,11 +5402,15 @@ func (p *raftProxyControlServer) UpdateNode(ctx context.Context, r *UpdateNodeRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateNode(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5419,11 +5437,15 @@ func (p *raftProxyControlServer) RemoveNode(ctx context.Context, r *RemoveNodeRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveNode(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5450,11 +5472,15 @@ func (p *raftProxyControlServer) GetTask(ctx context.Context, r *GetTaskRequest) conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetTask(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5481,11 +5507,15 @@ func (p *raftProxyControlServer) ListTasks(ctx context.Context, r *ListTasksRequ conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListTasks(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5512,11 +5542,15 @@ func (p *raftProxyControlServer) RemoveTask(ctx context.Context, r *RemoveTaskRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveTask(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5543,11 +5577,15 @@ func (p *raftProxyControlServer) GetService(ctx context.Context, r *GetServiceRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5574,11 +5612,15 @@ func (p *raftProxyControlServer) ListServices(ctx context.Context, r *ListServic conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListServices(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5605,11 +5647,15 @@ func (p *raftProxyControlServer) CreateService(ctx context.Context, r *CreateSer conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.CreateService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5636,11 +5682,15 @@ func (p *raftProxyControlServer) UpdateService(ctx context.Context, r *UpdateSer conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5667,11 +5717,15 @@ func (p *raftProxyControlServer) RemoveService(ctx context.Context, r *RemoveSer conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveService(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5698,11 +5752,15 @@ func (p *raftProxyControlServer) GetNetwork(ctx context.Context, r *GetNetworkRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5729,11 +5787,15 @@ func (p *raftProxyControlServer) ListNetworks(ctx context.Context, r *ListNetwor conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListNetworks(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5760,11 +5822,15 @@ func (p *raftProxyControlServer) CreateNetwork(ctx context.Context, r *CreateNet conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.CreateNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5791,11 +5857,15 @@ func (p *raftProxyControlServer) RemoveNetwork(ctx context.Context, r *RemoveNet conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5822,11 +5892,15 @@ func (p *raftProxyControlServer) GetCluster(ctx context.Context, r *GetClusterRe conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetCluster(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5853,11 +5927,15 @@ func (p *raftProxyControlServer) ListClusters(ctx context.Context, r *ListCluste conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListClusters(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5884,11 +5962,15 @@ func (p *raftProxyControlServer) UpdateCluster(ctx context.Context, r *UpdateClu conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateCluster(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5915,11 +5997,15 @@ func (p *raftProxyControlServer) GetSecret(ctx context.Context, r *GetSecretRequ conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5946,11 +6032,15 @@ func (p *raftProxyControlServer) UpdateSecret(ctx context.Context, r *UpdateSecr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -5977,11 +6067,15 @@ func (p *raftProxyControlServer) ListSecrets(ctx context.Context, r *ListSecrets conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ListSecrets(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -6008,11 +6102,15 @@ func (p *raftProxyControlServer) CreateSecret(ctx context.Context, r *CreateSecr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.CreateSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -6039,11 +6137,15 @@ func (p *raftProxyControlServer) RemoveSecret(ctx context.Context, r *RemoveSecr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.RemoveSecret(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/dispatcher.pb.go b/api/dispatcher.pb.go index 751c48d37c..7c5c70b5e7 100644 --- a/api/dispatcher.pb.go +++ b/api/dispatcher.pb.go @@ -1670,12 +1670,12 @@ func encodeVarintDispatcher(data []byte, offset int, v uint64) int { } type raftProxyDispatcherServer struct { - local DispatcherServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local DispatcherServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) DispatcherServer { +func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) DispatcherServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1692,18 +1692,24 @@ func NewRaftProxyDispatcherServer(local DispatcherServer, connSelector raftselec md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyDispatcherServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyDispatcherServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyDispatcherServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1735,17 +1741,33 @@ func (p *raftProxyDispatcherServer) pollNewLeaderConn(ctx context.Context) (*grp } } -func (p *raftProxyDispatcherServer) Session(r *SessionRequest, stream Dispatcher_SessionServer) error { +type Dispatcher_SessionServerWrapper struct { + Dispatcher_SessionServer + ctx context.Context +} +func (s Dispatcher_SessionServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyDispatcherServer) Session(r *SessionRequest, stream Dispatcher_SessionServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.Session(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Dispatcher_SessionServerWrapper{ + Dispatcher_SessionServer: stream, + ctx: ctx, + } + return p.local.Session(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1775,11 +1797,15 @@ func (p *raftProxyDispatcherServer) Heartbeat(ctx context.Context, r *HeartbeatR conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Heartbeat(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1806,11 +1832,15 @@ func (p *raftProxyDispatcherServer) UpdateTaskStatus(ctx context.Context, r *Upd conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.UpdateTaskStatus(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1832,17 +1862,33 @@ func (p *raftProxyDispatcherServer) UpdateTaskStatus(ctx context.Context, r *Upd return resp, err } -func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_TasksServer) error { +type Dispatcher_TasksServerWrapper struct { + Dispatcher_TasksServer + ctx context.Context +} + +func (s Dispatcher_TasksServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_TasksServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.Tasks(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Dispatcher_TasksServerWrapper{ + Dispatcher_TasksServer: stream, + ctx: ctx, + } + return p.local.Tasks(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1867,17 +1913,33 @@ func (p *raftProxyDispatcherServer) Tasks(r *TasksRequest, stream Dispatcher_Tas return nil } -func (p *raftProxyDispatcherServer) Assignments(r *AssignmentsRequest, stream Dispatcher_AssignmentsServer) error { +type Dispatcher_AssignmentsServerWrapper struct { + Dispatcher_AssignmentsServer + ctx context.Context +} + +func (s Dispatcher_AssignmentsServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyDispatcherServer) Assignments(r *AssignmentsRequest, stream Dispatcher_AssignmentsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.Assignments(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Dispatcher_AssignmentsServerWrapper{ + Dispatcher_AssignmentsServer: stream, + ctx: ctx, + } + return p.local.Assignments(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } diff --git a/api/health.pb.go b/api/health.pb.go index 13c40143df..5e53c97bd0 100644 --- a/api/health.pb.go +++ b/api/health.pb.go @@ -321,12 +321,12 @@ func encodeVarintHealth(data []byte, offset int, v uint64) int { } type raftProxyHealthServer struct { - local HealthServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local HealthServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) HealthServer { +func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) HealthServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -343,18 +343,24 @@ func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.Conn md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyHealthServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyHealthServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyHealthServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -391,11 +397,15 @@ func (p *raftProxyHealthServer) Check(ctx context.Context, r *HealthCheckRequest conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Check(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/logbroker.pb.go b/api/logbroker.pb.go index a066add9eb..b3dc176c96 100644 --- a/api/logbroker.pb.go +++ b/api/logbroker.pb.go @@ -1279,12 +1279,12 @@ func encodeVarintLogbroker(data []byte, offset int, v uint64) int { } type raftProxyLogsServer struct { - local LogsServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local LogsServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) LogsServer { +func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) LogsServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1301,18 +1301,24 @@ func NewRaftProxyLogsServer(local LogsServer, connSelector raftselector.ConnProv md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyLogsServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyLogsServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyLogsServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1344,17 +1350,33 @@ func (p *raftProxyLogsServer) pollNewLeaderConn(ctx context.Context) (*grpc.Clie } } -func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs_SubscribeLogsServer) error { +type Logs_SubscribeLogsServerWrapper struct { + Logs_SubscribeLogsServer + ctx context.Context +} +func (s Logs_SubscribeLogsServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs_SubscribeLogsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.SubscribeLogs(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := Logs_SubscribeLogsServerWrapper{ + Logs_SubscribeLogsServer: stream, + ctx: ctx, + } + return p.local.SubscribeLogs(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1380,12 +1402,12 @@ func (p *raftProxyLogsServer) SubscribeLogs(r *SubscribeLogsRequest, stream Logs } type raftProxyLogBrokerServer struct { - local LogBrokerServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local LogBrokerServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) LogBrokerServer { +func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) LogBrokerServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1402,18 +1424,24 @@ func NewRaftProxyLogBrokerServer(local LogBrokerServer, connSelector raftselecto md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyLogBrokerServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyLogBrokerServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyLogBrokerServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1445,17 +1473,33 @@ func (p *raftProxyLogBrokerServer) pollNewLeaderConn(ctx context.Context) (*grpc } } -func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsRequest, stream LogBroker_ListenSubscriptionsServer) error { +type LogBroker_ListenSubscriptionsServerWrapper struct { + LogBroker_ListenSubscriptionsServer + ctx context.Context +} +func (s LogBroker_ListenSubscriptionsServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsRequest, stream LogBroker_ListenSubscriptionsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.ListenSubscriptions(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := LogBroker_ListenSubscriptionsServerWrapper{ + LogBroker_ListenSubscriptionsServer: stream, + ctx: ctx, + } + return p.local.ListenSubscriptions(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1480,17 +1524,33 @@ func (p *raftProxyLogBrokerServer) ListenSubscriptions(r *ListenSubscriptionsReq return nil } -func (p *raftProxyLogBrokerServer) PublishLogs(stream LogBroker_PublishLogsServer) error { +type LogBroker_PublishLogsServerWrapper struct { + LogBroker_PublishLogsServer + ctx context.Context +} +func (s LogBroker_PublishLogsServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyLogBrokerServer) PublishLogs(stream LogBroker_PublishLogsServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.PublishLogs(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := LogBroker_PublishLogsServerWrapper{ + LogBroker_PublishLogsServer: stream, + ctx: ctx, + } + return p.local.PublishLogs(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } diff --git a/api/raft.pb.go b/api/raft.pb.go index e824d66a9c..8a96952f76 100644 --- a/api/raft.pb.go +++ b/api/raft.pb.go @@ -1498,12 +1498,12 @@ func encodeVarintRaft(data []byte, offset int, v uint64) int { } type raftProxyRaftServer struct { - local RaftServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local RaftServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RaftServer { +func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RaftServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1520,18 +1520,24 @@ func NewRaftProxyRaftServer(local RaftServer, connSelector raftselector.ConnProv md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyRaftServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyRaftServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyRaftServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1568,11 +1574,15 @@ func (p *raftProxyRaftServer) ProcessRaftMessage(ctx context.Context, r *Process conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ProcessRaftMessage(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1599,11 +1609,15 @@ func (p *raftProxyRaftServer) ResolveAddress(ctx context.Context, r *ResolveAddr conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.ResolveAddress(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1626,12 +1640,12 @@ func (p *raftProxyRaftServer) ResolveAddress(ctx context.Context, r *ResolveAddr } type raftProxyRaftMembershipServer struct { - local RaftMembershipServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local RaftMembershipServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RaftMembershipServer { +func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RaftMembershipServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1648,18 +1662,24 @@ func NewRaftProxyRaftMembershipServer(local RaftMembershipServer, connSelector r md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyRaftMembershipServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyRaftMembershipServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyRaftMembershipServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1696,11 +1716,15 @@ func (p *raftProxyRaftMembershipServer) Join(ctx context.Context, r *JoinRequest conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Join(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1727,11 +1751,15 @@ func (p *raftProxyRaftMembershipServer) Leave(ctx context.Context, r *LeaveReque conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Leave(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/api/resource.pb.go b/api/resource.pb.go index 52d1e4e4ab..a764a6ccee 100644 --- a/api/resource.pb.go +++ b/api/resource.pb.go @@ -451,12 +451,12 @@ func encodeVarintResource(data []byte, offset int, v uint64) int { } type raftProxyResourceAllocatorServer struct { - local ResourceAllocatorServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local ResourceAllocatorServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) ResourceAllocatorServer { +func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) ResourceAllocatorServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -473,18 +473,24 @@ func NewRaftProxyResourceAllocatorServer(local ResourceAllocatorServer, connSele md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyResourceAllocatorServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyResourceAllocatorServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyResourceAllocatorServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -521,11 +527,15 @@ func (p *raftProxyResourceAllocatorServer) AttachNetwork(ctx context.Context, r conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.AttachNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -552,11 +562,15 @@ func (p *raftProxyResourceAllocatorServer) DetachNetwork(ctx context.Context, r conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.DetachNetwork(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } diff --git a/ca/auth.go b/ca/auth.go index d81b543da0..bc7c629a54 100644 --- a/ca/auth.go +++ b/ca/auth.go @@ -16,6 +16,13 @@ import ( "google.golang.org/grpc/peer" ) +type localRequestKeyType struct{} + +// LocalRequestKey is a context key to mark a request that originating on the +// local node. The assocated value is a RemoteNodeInfo structure describing the +// local node. +var LocalRequestKey = localRequestKeyType{} + // LogTLSState logs information about the TLS connection and remote peers func LogTLSState(ctx context.Context, tlsState *tls.ConnectionState) { if tlsState == nil { @@ -189,6 +196,17 @@ type RemoteNodeInfo struct { // well as the forwarder's ID. This function does not do authorization checks - // it only looks up the node ID. func RemoteNode(ctx context.Context) (RemoteNodeInfo, error) { + // If we have a value on the context that marks this as a local + // request, we return the node info from the context. + localNodeInfo := ctx.Value(LocalRequestKey) + + if localNodeInfo != nil { + nodeInfo, ok := localNodeInfo.(RemoteNodeInfo) + if ok { + return nodeInfo, nil + } + } + certSubj, err := certSubjectFromContext(ctx) if err != nil { return RemoteNodeInfo{}, err diff --git a/ca/server.go b/ca/server.go index fa55d38534..29fc5a217f 100644 --- a/ca/server.go +++ b/ca/server.go @@ -211,6 +211,15 @@ func (s *Server) IssueNodeCertificate(ctx context.Context, request *api.IssueNod blacklistedCerts = clusters[0].BlacklistedCertificates } + // Renewing the cert with a local (unix socket) is always valid. + localNodeInfo := ctx.Value(LocalRequestKey) + if localNodeInfo != nil { + nodeInfo, ok := localNodeInfo.(RemoteNodeInfo) + if ok && nodeInfo.NodeID != "" { + return s.issueRenewCertificate(ctx, nodeInfo.NodeID, request.CSR) + } + } + // If the remote node is a worker (either forwarded by a manager, or calling directly), // issue a renew worker certificate entry with the correct ID nodeID, err := AuthorizeForwardedRoleAndOrg(ctx, []string{WorkerRole}, []string{ManagerRole}, s.securityConfig.ClientTLSCreds.Organization(), blacklistedCerts) diff --git a/manager/manager.go b/manager/manager.go index 649f229532..39e17295b2 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -328,23 +328,46 @@ func (m *Manager) Run(parent context.Context) error { authenticatedHealthAPI := api.NewAuthenticatedWrapperHealthServer(healthServer, authorize) authenticatedRaftMembershipAPI := api.NewAuthenticatedWrapperRaftMembershipServer(m.raftNode, authorize) - proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(authenticatedResourceAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - proxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(authenticatedLogBrokerAPI, m.raftNode, ca.WithMetadataForwardTLSInfo) - - // localProxyControlAPI is a special kind of proxy. It is only wired up - // to receive requests from a trusted local socket, and these requests - // don't use TLS, therefore the requests it handles locally should - // bypass authorization. When it proxies, it sends them as requests from - // this manager rather than forwarded requests (it has no TLS - // information to put in the metadata map). + proxyDispatcherAPI := api.NewRaftProxyDispatcherServer(authenticatedDispatcherAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyCAAPI := api.NewRaftProxyCAServer(authenticatedCAAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyNodeCAAPI := api.NewRaftProxyNodeCAServer(authenticatedNodeCAAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyRaftMembershipAPI := api.NewRaftProxyRaftMembershipServer(authenticatedRaftMembershipAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(authenticatedResourceAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + proxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(authenticatedLogBrokerAPI, m.raftNode, nil, ca.WithMetadataForwardTLSInfo) + + // The following local proxies are only wired up to receive requests + // from a trusted local socket, and these requests don't use TLS, + // therefore the requests they handle locally should bypass + // authorization. When requests are proxied from these servers, they + // are sent as requests from this manager rather than forwarded + // requests (it has no TLS information to put in the metadata map). forwardAsOwnRequest := func(ctx context.Context) (context.Context, error) { return ctx, nil } - localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, m.raftNode, forwardAsOwnRequest) - localProxyLogsAPI := api.NewRaftProxyLogsServer(m.logbroker, m.raftNode, forwardAsOwnRequest) - localCAAPI := api.NewRaftProxyCAServer(m.caserver, m.raftNode, forwardAsOwnRequest) + handleRequestLocally := func(ctx context.Context) (context.Context, error) { + var remoteAddr string + if m.config.RemoteAPI.AdvertiseAddr != "" { + remoteAddr = m.config.RemoteAPI.AdvertiseAddr + } else { + remoteAddr = m.config.RemoteAPI.ListenAddr + } + + creds := m.config.SecurityConfig.ClientTLSCreds + + nodeInfo := ca.RemoteNodeInfo{ + Roles: []string{creds.Role()}, + Organization: creds.Organization(), + NodeID: creds.NodeID(), + RemoteAddr: remoteAddr, + } + + return context.WithValue(ctx, ca.LocalRequestKey, nodeInfo), nil + } + localProxyControlAPI := api.NewRaftProxyControlServer(baseControlAPI, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyLogsAPI := api.NewRaftProxyLogsServer(m.logbroker, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyDispatcherAPI := api.NewRaftProxyDispatcherServer(m.dispatcher, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyCAAPI := api.NewRaftProxyCAServer(m.caserver, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyNodeCAAPI := api.NewRaftProxyNodeCAServer(m.caserver, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyResourceAPI := api.NewRaftProxyResourceAllocatorServer(baseResourceAPI, m.raftNode, handleRequestLocally, forwardAsOwnRequest) + localProxyLogBrokerAPI := api.NewRaftProxyLogBrokerServer(m.logbroker, m.raftNode, handleRequestLocally, forwardAsOwnRequest) // Everything registered on m.server should be an authenticated // wrapper, or a proxy wrapping an authenticated wrapper! @@ -362,7 +385,11 @@ func (m *Manager) Run(parent context.Context) error { api.RegisterControlServer(m.localserver, localProxyControlAPI) api.RegisterLogsServer(m.localserver, localProxyLogsAPI) api.RegisterHealthServer(m.localserver, localHealthServer) - api.RegisterCAServer(m.localserver, localCAAPI) + api.RegisterDispatcherServer(m.localserver, localProxyDispatcherAPI) + api.RegisterCAServer(m.localserver, localProxyCAAPI) + api.RegisterNodeCAServer(m.localserver, localProxyNodeCAAPI) + api.RegisterResourceAllocatorServer(m.localserver, localProxyResourceAPI) + api.RegisterLogBrokerServer(m.localserver, localProxyLogBrokerAPI) healthServer.SetServingStatus("Raft", api.HealthCheckResponse_NOT_SERVING) localHealthServer.SetServingStatus("ControlAPI", api.HealthCheckResponse_NOT_SERVING) diff --git a/protobuf/plugin/raftproxy/raftproxy.go b/protobuf/plugin/raftproxy/raftproxy.go index 931dfdf23b..bb8582113b 100644 --- a/protobuf/plugin/raftproxy/raftproxy.go +++ b/protobuf/plugin/raftproxy/raftproxy.go @@ -27,12 +27,12 @@ func (g *raftProxyGen) genProxyStruct(s *descriptor.ServiceDescriptorProto) { g.gen.P("type " + serviceTypeName(s) + " struct {") g.gen.P("\tlocal " + s.GetName() + "Server") g.gen.P("\tconnSelector raftselector.ConnProvider") - g.gen.P("\tctxMods []func(context.Context)(context.Context, error)") + g.gen.P("\tlocalCtxMods, remoteCtxMods []func(context.Context)(context.Context, error)") g.gen.P("}") } func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) { - g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, ctxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {") + g.gen.P("func NewRaftProxy" + s.GetName() + "Server(local " + s.GetName() + "Server, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context)(context.Context, error)) " + s.GetName() + "Server {") g.gen.P(`redirectChecker := func(ctx context.Context)(context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -49,21 +49,27 @@ func (g *raftProxyGen) genProxyConstructor(s *descriptor.ServiceDescriptorProto) md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context)(context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context)(context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context)(context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context)(context.Context, error){localCtxMod} + } `) g.gen.P("return &" + serviceTypeName(s) + `{ local: local, connSelector: connSelector, - ctxMods: mods, + localCtxMods: localMods, + remoteCtxMods: remoteMods, }`) g.gen.P("}") } func (g *raftProxyGen) genRunCtxMods(s *descriptor.ServiceDescriptorProto) { - g.gen.P("func (p *" + serviceTypeName(s) + `) runCtxMods(ctx context.Context) (context.Context, error) { + g.gen.P("func (p *" + serviceTypeName(s) + `) runCtxMods(ctx context.Context, ctxMods []func(context.Context)(context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -91,18 +97,43 @@ func sigPrefix(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescrip return "func (p *" + serviceTypeName(s) + ") " + m.GetName() + "(" } +func (g *raftProxyGen) genStreamWrapper(streamType string) { + // Generate stream wrapper that returns a modified context + g.gen.P(`type ` + streamType + `Wrapper struct { + ` + streamType + ` + ctx context.Context +} +`) + g.gen.P(`func (s ` + streamType + `Wrapper) Context() context.Context { + return s.ctx +} +`) +} + func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { - g.gen.P(sigPrefix(s, m) + "stream " + s.GetName() + "_" + m.GetName() + "Server) error {") - g.gen.P(` + streamType := s.GetName() + "_" + m.GetName() + "Server" + + // Generate stream wrapper that returns a modified context + g.genStreamWrapper(streamType) + + g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.` + m.GetName() + `(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := ` + streamType + `Wrapper{ + ` + streamType + `: stream, + ctx: ctx, + } + return p.local.` + m.GetName() + `(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) @@ -135,17 +166,28 @@ func (g *raftProxyGen) genClientStreamingMethod(s *descriptor.ServiceDescriptorP } func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { - g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + s.GetName() + "_" + m.GetName() + "Server) error {") - g.gen.P(` + streamType := s.GetName() + "_" + m.GetName() + "Server" + + g.genStreamWrapper(streamType) + + g.gen.P(sigPrefix(s, m) + "r *" + getInputTypeName(m) + ", stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.` + m.GetName() + `(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := ` + streamType + `Wrapper{ + ` + streamType + `: stream, + ctx: ctx, + } + return p.local.` + m.GetName() + `(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) @@ -172,17 +214,28 @@ func (g *raftProxyGen) genServerStreamingMethod(s *descriptor.ServiceDescriptorP } func (g *raftProxyGen) genClientServerStreamingMethod(s *descriptor.ServiceDescriptorProto, m *descriptor.MethodDescriptorProto) { - g.gen.P(sigPrefix(s, m) + "stream " + s.GetName() + "_" + m.GetName() + "Server) error {") - g.gen.P(` + streamType := s.GetName() + "_" + m.GetName() + "Server" + + g.genStreamWrapper(streamType) + + g.gen.P(sigPrefix(s, m) + "stream " + streamType + `) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.` + m.GetName() + `(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := ` + streamType + `Wrapper{ + ` + streamType + `: stream, + ctx: ctx, + } + return p.local.` + m.GetName() + `(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err }`) @@ -231,11 +284,15 @@ func (g *raftProxyGen) genSimpleMethod(s *descriptor.ServiceDescriptorProto, m * conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.` + m.GetName() + `(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err }`) diff --git a/protobuf/plugin/raftproxy/test/raftproxy_test.go b/protobuf/plugin/raftproxy/test/raftproxy_test.go index 2f4e6fd364..3dd8990661 100644 --- a/protobuf/plugin/raftproxy/test/raftproxy_test.go +++ b/protobuf/plugin/raftproxy/test/raftproxy_test.go @@ -51,7 +51,7 @@ func TestSimpleRedirect(t *testing.T) { cluster := &mockCluster{conn: conn} forwardAsOwnRequest := func(ctx context.Context) (context.Context, error) { return ctx, nil } - api := NewRaftProxyRouteGuideServer(testRouteGuide{}, cluster, forwardAsOwnRequest) + api := NewRaftProxyRouteGuideServer(testRouteGuide{}, cluster, nil, forwardAsOwnRequest) srv := grpc.NewServer() RegisterRouteGuideServer(srv, api) go srv.Serve(l) diff --git a/protobuf/plugin/raftproxy/test/service.pb.go b/protobuf/plugin/raftproxy/test/service.pb.go index 1ebe84bcfb..9285ec8964 100644 --- a/protobuf/plugin/raftproxy/test/service.pb.go +++ b/protobuf/plugin/raftproxy/test/service.pb.go @@ -906,12 +906,12 @@ func encodeVarintService(data []byte, offset int, v uint64) int { } type raftProxyRouteGuideServer struct { - local RouteGuideServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local RouteGuideServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyRouteGuideServer(local RouteGuideServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) RouteGuideServer { +func NewRaftProxyRouteGuideServer(local RouteGuideServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) RouteGuideServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -928,18 +928,24 @@ func NewRaftProxyRouteGuideServer(local RouteGuideServer, connSelector raftselec md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyRouteGuideServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyRouteGuideServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyRouteGuideServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -976,11 +982,15 @@ func (p *raftProxyRouteGuideServer) GetFeature(ctx context.Context, r *Point) (* conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.GetFeature(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err } @@ -1002,17 +1012,33 @@ func (p *raftProxyRouteGuideServer) GetFeature(ctx context.Context, r *Point) (* return resp, err } -func (p *raftProxyRouteGuideServer) ListFeatures(r *Rectangle, stream RouteGuide_ListFeaturesServer) error { +type RouteGuide_ListFeaturesServerWrapper struct { + RouteGuide_ListFeaturesServer + ctx context.Context +} +func (s RouteGuide_ListFeaturesServerWrapper) Context() context.Context { + return s.ctx +} + +func (p *raftProxyRouteGuideServer) ListFeatures(r *Rectangle, stream RouteGuide_ListFeaturesServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.ListFeatures(r, stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := RouteGuide_ListFeaturesServerWrapper{ + RouteGuide_ListFeaturesServer: stream, + ctx: ctx, + } + return p.local.ListFeatures(r, streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1037,17 +1063,33 @@ func (p *raftProxyRouteGuideServer) ListFeatures(r *Rectangle, stream RouteGuide return nil } -func (p *raftProxyRouteGuideServer) RecordRoute(stream RouteGuide_RecordRouteServer) error { +type RouteGuide_RecordRouteServerWrapper struct { + RouteGuide_RecordRouteServer + ctx context.Context +} + +func (s RouteGuide_RecordRouteServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyRouteGuideServer) RecordRoute(stream RouteGuide_RecordRouteServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.RecordRoute(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := RouteGuide_RecordRouteServerWrapper{ + RouteGuide_RecordRouteServer: stream, + ctx: ctx, + } + return p.local.RecordRoute(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1078,17 +1120,33 @@ func (p *raftProxyRouteGuideServer) RecordRoute(stream RouteGuide_RecordRouteSer return stream.SendAndClose(reply) } -func (p *raftProxyRouteGuideServer) RouteChat(stream RouteGuide_RouteChatServer) error { +type RouteGuide_RouteChatServerWrapper struct { + RouteGuide_RouteChatServer + ctx context.Context +} + +func (s RouteGuide_RouteChatServerWrapper) Context() context.Context { + return s.ctx +} +func (p *raftProxyRouteGuideServer) RouteChat(stream RouteGuide_RouteChatServer) error { ctx := stream.Context() conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { - return p.local.RouteChat(stream) + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return err + } + streamWrapper := RouteGuide_RouteChatServerWrapper{ + RouteGuide_RouteChatServer: stream, + ctx: ctx, + } + return p.local.RouteChat(streamWrapper) } return err } - ctx, err = p.runCtxMods(ctx) + ctx, err = p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return err } @@ -1131,12 +1189,12 @@ func (p *raftProxyRouteGuideServer) RouteChat(stream RouteGuide_RouteChatServer) } type raftProxyHealthServer struct { - local HealthServer - connSelector raftselector.ConnProvider - ctxMods []func(context.Context) (context.Context, error) + local HealthServer + connSelector raftselector.ConnProvider + localCtxMods, remoteCtxMods []func(context.Context) (context.Context, error) } -func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, ctxMod func(context.Context) (context.Context, error)) HealthServer { +func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.ConnProvider, localCtxMod, remoteCtxMod func(context.Context) (context.Context, error)) HealthServer { redirectChecker := func(ctx context.Context) (context.Context, error) { s, ok := transport.StreamFromContext(ctx) if !ok { @@ -1153,18 +1211,24 @@ func NewRaftProxyHealthServer(local HealthServer, connSelector raftselector.Conn md["redirect"] = append(md["redirect"], addr) return metadata.NewContext(ctx, md), nil } - mods := []func(context.Context) (context.Context, error){redirectChecker} - mods = append(mods, ctxMod) + remoteMods := []func(context.Context) (context.Context, error){redirectChecker} + remoteMods = append(remoteMods, remoteCtxMod) + + var localMods []func(context.Context) (context.Context, error) + if localCtxMod != nil { + localMods = []func(context.Context) (context.Context, error){localCtxMod} + } return &raftProxyHealthServer{ - local: local, - connSelector: connSelector, - ctxMods: mods, + local: local, + connSelector: connSelector, + localCtxMods: localMods, + remoteCtxMods: remoteMods, } } -func (p *raftProxyHealthServer) runCtxMods(ctx context.Context) (context.Context, error) { +func (p *raftProxyHealthServer) runCtxMods(ctx context.Context, ctxMods []func(context.Context) (context.Context, error)) (context.Context, error) { var err error - for _, mod := range p.ctxMods { + for _, mod := range ctxMods { ctx, err = mod(ctx) if err != nil { return ctx, err @@ -1201,11 +1265,15 @@ func (p *raftProxyHealthServer) Check(ctx context.Context, r *HealthCheckRequest conn, err := p.connSelector.LeaderConn(ctx) if err != nil { if err == raftselector.ErrIsLeader { + ctx, err = p.runCtxMods(ctx, p.localCtxMods) + if err != nil { + return nil, err + } return p.local.Check(ctx, r) } return nil, err } - modCtx, err := p.runCtxMods(ctx) + modCtx, err := p.runCtxMods(ctx, p.remoteCtxMods) if err != nil { return nil, err }