From 4ad495e1df3dc9f6637a2fef87d5db4d9d27db6f Mon Sep 17 00:00:00 2001 From: mstmdev Date: Wed, 30 Oct 2024 10:50:53 +0800 Subject: [PATCH] Fix the issue of not refreshing the token expiration time when reconnecting (#394) Fixes #393 --- api/apiclient/client.go | 4 ++++ api/apiclient/grpc_client.go | 24 ++++++++++++------------ monitor/remote_client_monitor.go | 9 +++++++++ 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/api/apiclient/client.go b/api/apiclient/client.go index a81a17e4..d7451c51 100644 --- a/api/apiclient/client.go +++ b/api/apiclient/client.go @@ -18,6 +18,10 @@ type Client interface { Monitor() (monitor.MonitorService_MonitorClient, error) // IsClosed is connection closed of the current client IsClosed(err error) bool + // IsUnauthenticated check whether the error is unauthorized + IsUnauthenticated(err error) bool // SubscribeTask register a task client to the task server and wait to receive task SubscribeTask(clientInfo *task.ClientInfo) (task.TaskService_SubscribeTaskClient, error) + // Login login to the server + Login() (err error) } diff --git a/api/apiclient/grpc_client.go b/api/apiclient/grpc_client.go index f563a633..19827236 100644 --- a/api/apiclient/grpc_client.go +++ b/api/apiclient/grpc_client.go @@ -53,7 +53,7 @@ func (c *client) Start() (err error) { if err = c.connect(); err != nil { return err } - return c.login() + return c.Login() } func (c *client) connect() (err error) { @@ -89,10 +89,10 @@ func (c *client) getInfo() (*info.FileServerInfo, error) { func (c *client) GetInfo() (*info.FileServerInfo, error) { fsi, err := c.getInfo() - if !c.needLogin(err) { + if !c.IsUnauthenticated(err) { return fsi, err } - if err = c.login(); err != nil { + if err = c.Login(); err != nil { return nil, err } return c.getInfo() @@ -104,10 +104,10 @@ func (c *client) monitor() (monitor.MonitorService_MonitorClient, error) { func (c *client) Monitor() (monitor.MonitorService_MonitorClient, error) { fsi, err := c.monitor() - if !c.needLogin(err) { + if !c.IsUnauthenticated(err) { return fsi, err } - if err = c.login(); err != nil { + if err = c.Login(); err != nil { return nil, err } return c.monitor() @@ -117,16 +117,20 @@ func (c *client) IsClosed(err error) bool { return status.Code(err) == codes.Unavailable } +func (c *client) IsUnauthenticated(err error) bool { + return status.Code(err) == codes.Unauthenticated +} + func (c *client) subscribeTask(clientInfo *task.ClientInfo) (task.TaskService_SubscribeTaskClient, error) { return c.TaskServiceClient.SubscribeTask(context.Background(), clientInfo, grpc.PerRPCCredentials(c.creds)) } func (c *client) SubscribeTask(clientInfo *task.ClientInfo) (task.TaskService_SubscribeTaskClient, error) { rc, err := c.subscribeTask(clientInfo) - if !c.needLogin(err) { + if !c.IsUnauthenticated(err) { return rc, err } - if err = c.login(); err != nil { + if err = c.Login(); err != nil { return nil, err } return c.subscribeTask(clientInfo) @@ -144,11 +148,7 @@ func (c *client) getToken() (token string, err error) { return reply.Token, nil } -func (c *client) needLogin(err error) bool { - return status.Code(err) == codes.Unauthenticated -} - -func (c *client) login() (err error) { +func (c *client) Login() (err error) { token, err := c.getToken() if err == nil { oauth2Token := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token}) diff --git a/monitor/remote_client_monitor.go b/monitor/remote_client_monitor.go index 13985651..177e6f40 100644 --- a/monitor/remote_client_monitor.go +++ b/monitor/remote_client_monitor.go @@ -149,9 +149,18 @@ func (m *remoteClientMonitor) readMessage(st *atomic.Bool, wd wait.Done) { nmc, err := m.client.Monitor() if err == nil { mc = nmc + m.logger.Info("monitor the remote server success") } return err }, "monitor the remote server") + } else if m.client.IsUnauthenticated(err) { + if m.logger.ErrorIf(m.client.Login(), "re-login to remote server error") == nil { + m.logger.Info("re-login to remote server success") + if nmc, err := m.client.Monitor(); err == nil { + mc = nmc + m.logger.Info("monitor the remote server success") + } + } } } else { m.messages.PushBack(msg)