diff --git a/pkg/networkservice/common/timeout/server.go b/pkg/networkservice/common/timeout/server.go
index 01bd176da3..e8f2fa538f 100644
--- a/pkg/networkservice/common/timeout/server.go
+++ b/pkg/networkservice/common/timeout/server.go
@@ -29,13 +29,14 @@ import (
 	"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
 	"github.com/networkservicemesh/sdk/pkg/tools/extend"
+	"github.com/networkservicemesh/sdk/pkg/tools/fifosync"
 	"github.com/networkservicemesh/sdk/pkg/tools/log"
-	"github.com/networkservicemesh/sdk/pkg/tools/multiexecutor"
 )
 
 type timeoutServer struct {
+	ctx         context.Context
 	connections timerMap
-	executor    *multiexecutor.Executor
+	mutexGroup  fifosync.MutexGroup
 }
 
 type timer struct {
@@ -46,7 +47,7 @@ type timer struct {
 // NewServer - creates a new NetworkServiceServer chain element that implements timeout of expired connections.
 func NewServer(ctx context.Context) networkservice.NetworkServiceServer {
 	return &timeoutServer{
-		executor: multiexecutor.NewExecutor(ctx),
+		ctx: ctx,
 	}
 }
 
@@ -54,31 +55,33 @@ func (t *timeoutServer) Request(ctx context.Context, request *networkservice.Net
 	logEntry := log.Entry(ctx).WithField("timeoutServer", "Request")
 
 	connID := request.GetConnection().GetId()
 
-	<-t.executor.AsyncExec(connID, func() {
-		if timer, ok := t.connections.Load(connID); ok {
-			if !timer.timer.Stop() {
-				logEntry.Warnf("connection has been timed out, re requesting: %v", connID)
-			}
-			close(timer.stopCh)
-			t.connections.Delete(connID)
-		}
-		conn, err = next.Server(ctx).Request(ctx, request)
-		if err != nil {
-			return
-		}
-
-		var timer *timer
-		timer, err = t.createTimer(ctx, conn)
-		if err != nil {
-			if _, closeErr := next.Server(ctx).Close(ctx, conn); closeErr != nil {
-				err = errors.Wrapf(err, "error attempting to close failed connection %v: %+v", connID, closeErr)
-			}
-			return
-		}
-
-		t.connections.Store(connID, timer)
-	})
+	t.mutexGroup.Lock(connID)
+	defer t.mutexGroup.Unlock(connID)
+
+	if timer, ok := t.connections.Load(connID); ok {
+		if !timer.timer.Stop() {
+			logEntry.Warnf("connection has been timed out, re requesting: %v", connID)
+		}
+		close(timer.stopCh)
+		t.connections.Delete(connID)
+	}
+
+	conn, err = next.Server(ctx).Request(ctx, request)
+	if err != nil {
+		return
+	}
+
+	var timer *timer
+	timer, err = t.createTimer(ctx, conn)
+	if err != nil {
+		if _, closeErr := next.Server(ctx).Close(ctx, conn); closeErr != nil {
+			err = errors.Wrapf(err, "error attempting to close failed connection %v: %+v", connID, closeErr)
+		}
+		return
+	}
+
+	t.connections.Store(connID, timer)
 
 	return conn, err
 }
@@ -92,35 +95,36 @@ func (t *timeoutServer) createTimer(ctx context.Context, conn *networkservice.Co
 	}
 
 	conn = conn.Clone()
-	ctx = extend.WithValuesFromContext(context.Background(), ctx)
+	timerCtx := extend.WithValuesFromContext(context.Background(), t.ctx)
 
 	timer := &timer{
 		stopCh: make(chan struct{}, 1),
 	}
 	timer.timer = time.AfterFunc(time.Until(expireTime), func() {
-		t.executor.AsyncExec(conn.GetId(), func() {
-			select {
-			case <-timer.stopCh:
-				logEntry.Warnf("timer has been already stopped: %v", conn.GetId())
-			default:
-				if err := t.close(ctx, conn); err != nil {
-					logEntry.Errorf("failed to close timed out connection: %v %+v", conn.GetId(), err)
-				}
-			}
-		})
+		t.mutexGroup.Lock(conn.GetId())
+		defer t.mutexGroup.Unlock(conn.GetId())
+
+		select {
+		case <-timer.stopCh:
+			logEntry.Warnf("timer has been already stopped: %v", conn.GetId())
+		default:
+			if err := t.close(timerCtx, conn, next.Server(ctx)); err != nil {
+				logEntry.Errorf("failed to close timed out connection: %v %+v", conn.GetId(), err)
+			}
+		}
 	})
 
 	return timer, nil
 }
 
 func (t *timeoutServer) Close(ctx context.Context, conn *networkservice.Connection) (_ *empty.Empty, err error) {
-	<-t.executor.AsyncExec(conn.GetId(), func() {
-		err = t.close(ctx, conn)
-	})
-	return &empty.Empty{}, err
+	t.mutexGroup.Lock(conn.GetId())
+	defer t.mutexGroup.Unlock(conn.GetId())
+
+	return &empty.Empty{}, t.close(ctx, conn, next.Server(ctx))
 }
 
-func (t *timeoutServer) close(ctx context.Context, conn *networkservice.Connection) error {
+func (t *timeoutServer) close(ctx context.Context, conn *networkservice.Connection, nextServer networkservice.NetworkServiceServer) error {
 	logEntry := log.Entry(ctx).WithField("timeoutServer", "close")
 
 	timer, ok := t.connections.Load(conn.GetId())
@@ -133,6 +137,6 @@ func (t *timeoutServer) close(ctx context.Context, conn *networkservice.Connecti
 	close(timer.stopCh)
 	t.connections.Delete(conn.GetId())
 
-	_, err := next.Server(ctx).Close(ctx, conn)
+	_, err := nextServer.Close(ctx, conn)
 	return err
 }
diff --git a/pkg/tools/multiexecutor/multi_executor.go b/pkg/tools/multiexecutor/multi_executor.go
deleted file mode 100644
index d2459e402c..0000000000
--- a/pkg/tools/multiexecutor/multi_executor.go
+++ /dev/null
@@ -1,105 +0,0 @@
-// Copyright (c) 2020 Doc.ai and/or its affiliates.
-//
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at:
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package multiexecutor provides serialize.Executor with multiple task queues matched by IDs
-package multiexecutor
-
-import (
-	"context"
-	"sync/atomic"
-	"time"
-
-	"github.com/edwarnicke/serialize"
-)
-
-const (
-	cleanupTimeout = 10 * time.Millisecond
-)
-
-// Executor is a serialize.Executor with multiple task queues matched by IDs
-type Executor struct {
-	executor  serialize.Executor
-	executors map[string]*executor
-	updated   int32
-}
-
-type executor struct {
-	executor serialize.Executor
-	count    int32
-}
-
-func NewExecutor(ctx context.Context) *Executor {
-	e := &Executor{
-		executors: map[string]*executor{},
-	}
-
-	go e.cleanup(ctx)
-
-	return e
-}
-
-func (e *Executor) cleanup(ctx context.Context) {
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-time.After(cleanupTimeout):
-			<-e.executor.AsyncExec(func() {
-				if atomic.CompareAndSwapInt32(&e.updated, 1, 0) {
-					return
-				}
-				for id, ex := range e.executors {
-					if atomic.LoadInt32(&ex.count) == 0 {
-						delete(e.executors, id)
-					}
-				}
-			})
-		}
-	}
-}
-
-// AsyncExec starts task in the queue selected by the given ID (look at serialize.Executor)
-func (e *Executor) AsyncExec(id string, f func()) <-chan struct{} {
-	ready := make(chan struct{})
-	chanPtr := new(<-chan struct{})
-	e.executor.AsyncExec(func() {
-		ex, ok := e.executors[id]
-		if !ok {
-			ex = &executor{}
-			e.executors[id] = ex
-		}
-		atomic.AddInt32(&ex.count, 1)
-
-		*chanPtr = ex.executor.AsyncExec(func() {
-			f()
-			atomic.AddInt32(&ex.count, -1)
-			atomic.CompareAndSwapInt32(&e.updated, 0, 1)
-		})
-		close(ready)
-	})
-
-	return wrapChanPtr(ready, chanPtr)
-}
-
-func wrapChanPtr(ready <-chan struct{}, chanPtr *<-chan struct{}) <-chan struct{} {
-	rv := make(chan struct{})
-	go func() {
-		<-ready
-		<-*chanPtr
-		close(rv)
-	}()
-	return rv
-}
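Note on the new dependency: the timeout server above relies on fifosync.MutexGroup only through Lock(id) and Unlock(id), which serialize Request, Close, and the expiration callback for a single connection ID in arrival order while leaving different connections free to run concurrently. Purely as an illustration of that contract, here is a minimal sketch of a per-ID FIFO mutex group; it is an assumption for explanatory purposes, not the actual pkg/tools/fifosync code, and everything inside it beyond the Lock/Unlock signatures is hypothetical.

// Illustrative sketch only: an assumed, simplified per-ID FIFO mutex group.
// It is NOT the real pkg/tools/fifosync implementation; it only demonstrates
// the Lock(id)/Unlock(id) contract the timeout server depends on.
package fifosync

import "sync"

// MutexGroup serializes callers per string ID in FIFO order.
// The zero value is ready to use.
type MutexGroup struct {
	mu     sync.Mutex
	queues map[string][]chan struct{}
}

// Lock enqueues the caller for the given ID and blocks until every earlier
// Lock(id) has been matched by an Unlock(id).
func (g *MutexGroup) Lock(id string) {
	g.mu.Lock()
	if g.queues == nil {
		g.queues = map[string][]chan struct{}{}
	}
	turn := make(chan struct{})
	g.queues[id] = append(g.queues[id], turn)
	first := len(g.queues[id]) == 1
	g.mu.Unlock()

	if first {
		return // the lock for this ID was free
	}
	<-turn // closed by the previous holder's Unlock
}

// Unlock releases the lock for the given ID and wakes the next waiter, if
// any; once the queue drains, the ID's entry is removed from the map.
func (g *MutexGroup) Unlock(id string) {
	g.mu.Lock()
	defer g.mu.Unlock()

	queue := g.queues[id][1:] // drop the current holder
	if len(queue) == 0 {
		delete(g.queues, id)
		return
	}
	g.queues[id] = queue
	close(queue[0]) // hand the lock to the next waiter in FIFO order
}

Unlike multiexecutor.Executor, which needed a context-bound background goroutine to garbage-collect idle per-ID queues, a mutex group along these lines can drop an ID's state as soon as its queue drains; in the diff above the server now keeps ctx only to build the timeout context via extend.WithValuesFromContext.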