Skip to content

Commit 5b2e20e

Browse files
committed
Implemented workers-cap in msgpackrpc connection
1 parent 0fc59a6 commit 5b2e20e

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

internal/msgpackrouter/router.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) {
8181
defer conn.Close()
8282

8383
var msgpackconn *msgpackrpc.Connection
84-
msgpackconn = msgpackrpc.NewConnection(conn, conn,
84+
msgpackconn = msgpackrpc.NewConnectionWithMaxWorkers(conn, conn,
8585
func(ctx context.Context, _ msgpackrpc.FunctionLogger, method string, params []any) (_result any, _err any) {
8686
// This handler is called when a request is received from the client
8787
slog.Debug("Received request", "method", method, "params", params)
@@ -160,6 +160,7 @@ func (r *Router) connectionLoop(conn io.ReadWriteCloser) {
160160
}
161161
slog.Error("Error in connection", "err", err)
162162
},
163+
r.sendQueueSize,
163164
)
164165

165166
msgpackconn.Run()

msgpackrpc/connection.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ type Connection struct {
5252
activeOutRequests map[MessageID]*outRequest
5353
activeOutRequestsMutex sync.Mutex
5454
lastOutRequestsIndex atomic.Uint32
55+
56+
workerSlots chan bool
5557
}
5658

5759
type inRequest struct {
@@ -79,8 +81,14 @@ type NotificationHandler func(logger FunctionLogger, method string, params []any
7981
// sending a request or notification.
8082
type ErrorHandler func(error)
8183

82-
// NewConnection starts a new
84+
// NewConnection creates a new MessagePack-RPC Connection handler.
8385
func NewConnection(in io.ReadCloser, out io.WriteCloser, requestHandler RequestHandler, notificationHandler NotificationHandler, errorHandler ErrorHandler) *Connection {
86+
return NewConnectionWithMaxWorkers(in, out, requestHandler, notificationHandler, errorHandler, 0)
87+
}
88+
89+
// NewConnectionWithMaxWorkers creates a new MessagePack-RPC Connection handler
90+
// with a specified maximum number of worker goroutines to handle incoming requests.
91+
func NewConnectionWithMaxWorkers(in io.ReadCloser, out io.WriteCloser, requestHandler RequestHandler, notificationHandler NotificationHandler, errorHandler ErrorHandler, maxWorkers int) *Connection {
8492
outEncoder := msgpack.NewEncoder(out)
8593
outEncoder.UseCompactInts(true)
8694
if requestHandler == nil {
@@ -109,9 +117,24 @@ func NewConnection(in io.ReadCloser, out io.WriteCloser, requestHandler RequestH
109117
activeOutRequests: map[MessageID]*outRequest{},
110118
logger: NullLogger{},
111119
}
120+
if maxWorkers > 0 {
121+
conn.workerSlots = make(chan bool, maxWorkers)
122+
}
112123
return conn
113124
}
114125

126+
func (c *Connection) startWorker(cb func()) {
127+
if c.workerSlots == nil {
128+
go cb()
129+
return
130+
}
131+
c.workerSlots <- true
132+
go func() {
133+
defer func() { <-c.workerSlots }()
134+
cb()
135+
}()
136+
}
137+
115138
func (c *Connection) SetLogger(l Logger) {
116139
c.loggerMutex.Lock()
117140
c.logger = l
@@ -215,7 +238,7 @@ func (c *Connection) handleIncomingRequest(id MessageID, method string, params [
215238
logger := c.logger.LogIncomingRequest(id, method, params)
216239
c.loggerMutex.Unlock()
217240

218-
go func() {
241+
c.startWorker(func() {
219242
reqResult, reqError := c.requestHandler(ctx, logger, method, params)
220243

221244
var existing *inRequest
@@ -238,7 +261,7 @@ func (c *Connection) handleIncomingRequest(id MessageID, method string, params [
238261
c.errorHandler(fmt.Errorf("error sending response: %w", err))
239262
c.Close()
240263
}
241-
}()
264+
})
242265
}
243266

244267
func (c *Connection) handleIncomingNotification(method string, params []any) {
@@ -261,7 +284,9 @@ func (c *Connection) handleIncomingNotification(method string, params []any) {
261284
logger := c.logger.LogIncomingNotification(method, params)
262285
c.loggerMutex.Unlock()
263286

264-
go c.notificationHandler(logger, method, params)
287+
c.startWorker(func() {
288+
c.notificationHandler(logger, method, params)
289+
})
265290
}
266291

267292
func (c *Connection) handleIncomingResponse(id MessageID, reqError any, reqResult any) {

0 commit comments

Comments
 (0)