Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions proxy/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ type ModelConfig struct {
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`

// Limit concurrency of HTTP requests to process
ConcurrencyLimit int `yaml:"concurrencyLimit"`
}

func (m *ModelConfig) SanitizedCommand() ([]string, error) {
Expand Down
20 changes: 20 additions & 0 deletions proxy/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,19 @@ type Process struct {
// for managing shutdown state
shutdownCtx context.Context
shutdownCancel context.CancelFunc

// for managing concurrency limits
concurrencyLimitSemaphore chan struct{}
}

func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
concurrentLimit := 10
if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit
} else {
proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID)
}
return &Process{
ID: ID,
config: config,
Expand All @@ -73,6 +82,9 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,

// concurrency limit
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
}
}

Expand Down Expand Up @@ -417,6 +429,14 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
return
}

select {
case p.concurrencyLimitSemaphore <- struct{}{}:
defer func() { <-p.concurrencyLimitSemaphore }()
default:
http.Error(w, "Too many requests", http.StatusTooManyRequests)
return
}

p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
Expand Down
32 changes: 32 additions & 0 deletions proxy/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,35 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}

// TestProcess_ConcurrencyLimit verifies that ProxyRequest answers with
// 429 Too Many Requests when the process's concurrency limit is already
// taken by another in-flight request.
func TestProcess_ConcurrencyLimit(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping long concurrency limit test")
	}

	expectedMessage := "concurrency_limit_test"
	config := getTestSimpleResponderConfig(expectedMessage)

	// only allow 1 concurrent request at a time
	config.ConcurrencyLimit = 1

	process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
	assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
	defer process.Stop()

	// Launch a slow request in the background to take up the semaphore.
	// firstDone is closed once that request's assertion has run, so the test
	// cannot return (and t cannot be used after completion) before then.
	firstDone := make(chan struct{})
	go func() {
		defer close(firstDone)
		req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
		w := httptest.NewRecorder()
		process.ProxyRequest(w, req1)
		assert.Equal(t, http.StatusOK, w.Code)
	}()

	// let the goroutine start
	time.Sleep(25 * time.Millisecond)

	// a second request issued while the first is still in flight must be rejected
	denied := httptest.NewRequest("GET", "/test", nil)
	w := httptest.NewRecorder()
	process.ProxyRequest(w, denied)
	assert.Equal(t, http.StatusTooManyRequests, w.Code)

	// Wait for the background request so its assertion is actually checked
	// before the test (and the deferred process.Stop) completes.
	<-firstDone
}