Skip to content

Commit

Permalink
Fix the issue of not refreshing the token expiration time when reconn…
Browse files Browse the repository at this point in the history
…ecting (#394)

Fixes #393
  • Loading branch information
mstmdev authored Oct 30, 2024
1 parent f3aff20 commit 4ad495e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
4 changes: 4 additions & 0 deletions api/apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ type Client interface {
Monitor() (monitor.MonitorService_MonitorClient, error)
// IsClosed is connection closed of the current client
IsClosed(err error) bool
// IsUnauthenticated check whether the error is unauthorized
IsUnauthenticated(err error) bool
// SubscribeTask register a task client to the task server and wait to receive task
SubscribeTask(clientInfo *task.ClientInfo) (task.TaskService_SubscribeTaskClient, error)
// Login login to the server
Login() (err error)
}
24 changes: 12 additions & 12 deletions api/apiclient/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (c *client) Start() (err error) {
if err = c.connect(); err != nil {
return err
}
return c.login()
return c.Login()
}

func (c *client) connect() (err error) {
Expand Down Expand Up @@ -89,10 +89,10 @@ func (c *client) getInfo() (*info.FileServerInfo, error) {

func (c *client) GetInfo() (*info.FileServerInfo, error) {
fsi, err := c.getInfo()
if !c.needLogin(err) {
if !c.IsUnauthenticated(err) {
return fsi, err
}
if err = c.login(); err != nil {
if err = c.Login(); err != nil {
return nil, err
}
return c.getInfo()
Expand All @@ -104,10 +104,10 @@ func (c *client) monitor() (monitor.MonitorService_MonitorClient, error) {

func (c *client) Monitor() (monitor.MonitorService_MonitorClient, error) {
fsi, err := c.monitor()
if !c.needLogin(err) {
if !c.IsUnauthenticated(err) {
return fsi, err
}
if err = c.login(); err != nil {
if err = c.Login(); err != nil {
return nil, err
}
return c.monitor()
Expand All @@ -117,16 +117,20 @@ func (c *client) IsClosed(err error) bool {
return status.Code(err) == codes.Unavailable
}

func (c *client) IsUnauthenticated(err error) bool {
return status.Code(err) == codes.Unauthenticated
}

func (c *client) subscribeTask(clientInfo *task.ClientInfo) (task.TaskService_SubscribeTaskClient, error) {
return c.TaskServiceClient.SubscribeTask(context.Background(), clientInfo, grpc.PerRPCCredentials(c.creds))
}

func (c *client) SubscribeTask(clientInfo *task.ClientInfo) (task.TaskService_SubscribeTaskClient, error) {
rc, err := c.subscribeTask(clientInfo)
if !c.needLogin(err) {
if !c.IsUnauthenticated(err) {
return rc, err
}
if err = c.login(); err != nil {
if err = c.Login(); err != nil {
return nil, err
}
return c.subscribeTask(clientInfo)
Expand All @@ -144,11 +148,7 @@ func (c *client) getToken() (token string, err error) {
return reply.Token, nil
}

func (c *client) needLogin(err error) bool {
return status.Code(err) == codes.Unauthenticated
}

func (c *client) login() (err error) {
func (c *client) Login() (err error) {
token, err := c.getToken()
if err == nil {
oauth2Token := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})
Expand Down
9 changes: 9 additions & 0 deletions monitor/remote_client_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,18 @@ func (m *remoteClientMonitor) readMessage(st *atomic.Bool, wd wait.Done) {
nmc, err := m.client.Monitor()
if err == nil {
mc = nmc
m.logger.Info("monitor the remote server success")
}
return err
}, "monitor the remote server")
} else if m.client.IsUnauthenticated(err) {
if m.logger.ErrorIf(m.client.Login(), "re-login to remote server error") == nil {
m.logger.Info("re-login to remote server success")
if nmc, err := m.client.Monitor(); err == nil {
mc = nmc
m.logger.Info("monitor the remote server success")
}
}
}
} else {
m.messages.PushBack(msg)
Expand Down

0 comments on commit 4ad495e

Please sign in to comment.