Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 88 additions & 25 deletions proxy/metrics_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (mp *metricsMonitor) wrapHandler(
tm := TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
DurationMs: int(time.Since(recorder.StartTime()).Milliseconds()),
DurationMs: int(recorder.Timings().totalDuration().Milliseconds()),
}

body := recorder.body.Bytes()
Expand All @@ -241,27 +241,18 @@ func (mp *metricsMonitor) wrapHandler(
}
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if parsed, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
if parsed, err := processStreamingResponse(modelID, request.URL.Path, recorder.Timings(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsed
}
} else {
if gjson.ValidBytes(body) {
parsed := gjson.ParseBytes(body)
usage := parsed.Get("usage")
timings := parsed.Get("timings")

// extract timings for infill - response is an array, timings are in the last element
// see #463
if strings.HasPrefix(request.URL.Path, "/infill") {
if arr := parsed.Array(); len(arr) > 0 {
timings = arr[len(arr)-1].Get("timings")
}
}
usage, timings := findMetricsPayload(parsed, request.URL.Path)

if usage.Exists() || timings.Exists() {
if parsedMetrics, err := parseMetrics(modelID, recorder.StartTime(), usage, timings); err != nil {
if parsedMetrics, err := parseMetrics(modelID, recorder.Timings(), usage, timings, false); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s, recording minimal metrics", err, request.URL.Path)
} else {
tm = parsedMetrics
Expand Down Expand Up @@ -307,7 +298,7 @@ func (mp *metricsMonitor) wrapHandler(
return nil
}

func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
func processStreamingResponse(modelID, reqPath string, timingInfo responseTimingInfo, body []byte) (TokenMetrics, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.

Expand Down Expand Up @@ -347,19 +338,51 @@ func processStreamingResponse(modelID string, start time.Time, body []byte) (Tok

if gjson.ValidBytes(data) {
parsed := gjson.ParseBytes(data)
usage := parsed.Get("usage")
timings := parsed.Get("timings")
usage, timings := findMetricsPayload(parsed, reqPath)

if usage.Exists() || timings.Exists() {
return parseMetrics(modelID, start, usage, timings)
return parseMetrics(modelID, timingInfo, usage, timings, true)
}
}
}

return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}

func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result) (TokenMetrics, error) {
// findMetricsPayload searches a parsed JSON payload for "usage" and
// "timings" objects. The top-level value is checked first, then the nested
// locations used by wrapped event payloads: "data", "response", and
// "data.response". The first candidate carrying either field wins; both
// results are zero values when nothing matches.
func findMetricsPayload(parsed gjson.Result, reqPath string) (gjson.Result, gjson.Result) {
	isInfill := strings.HasPrefix(reqPath, "/infill")

	for _, path := range []string{"", "data", "response", "data.response"} {
		candidate := parsed
		if path != "" {
			candidate = parsed.Get(path)
			if !candidate.Exists() {
				continue
			}
		}

		usage := candidate.Get("usage")
		timings := candidate.Get("timings")

		// infill responses are an array; the timings live in the last
		// element — see #463
		if isInfill {
			if arr := candidate.Array(); len(arr) > 0 {
				timings = arr[len(arr)-1].Get("timings")
			}
		}

		if usage.Exists() || timings.Exists() {
			return usage, timings
		}
	}

	return gjson.Result{}, gjson.Result{}
}

func parseMetrics(modelID string, timingInfo responseTimingInfo, usage, timings gjson.Result, allowFallback bool) (TokenMetrics, error) {
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
Expand All @@ -368,7 +391,7 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(start).Milliseconds())
durationMs := int(timingInfo.totalDuration().Milliseconds())

if usage.Exists() {
if pt := usage.Get("prompt_tokens"); pt.Exists() {
Expand Down Expand Up @@ -402,6 +425,10 @@ func parseMetrics(modelID string, start time.Time, usage, timings gjson.Result)
if cachedValue := timings.Get("cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
} else if allowFallback {
if generationDuration := timingInfo.generationDuration(); generationDuration > 0 && outputTokens > 1 {
tokensPerSecond = float64(outputTokens-1) / generationDuration.Seconds()
}
}

return TokenMetrics{
Expand Down Expand Up @@ -439,9 +466,11 @@ func decompressBody(body []byte, encoding string) ([]byte, error) {
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
	gin.ResponseWriter
	body         *bytes.Buffer // captured copy of the response body for later parsing
	tee          io.Writer     // multi-writer: real response + body buffer in one Write
	requestStart time.Time     // set at construction, before the upstream handler runs
	firstWrite   time.Time     // time of the first body write; zero if nothing was written
	lastWrite    time.Time     // time of the most recent body write; zero if nothing was written
}

func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
Expand All @@ -450,13 +479,16 @@ func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
ResponseWriter: w,
body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer),
requestStart: time.Now(),
}
}

func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
now := time.Now()
if w.firstWrite.IsZero() {
w.firstWrite = now
}
w.lastWrite = now

// Single write operation that writes to both the response and buffer
return w.tee.Write(b)
Expand All @@ -471,7 +503,38 @@ func (w *responseBodyCopier) Header() http.Header {
}

// StartTime returns the timestamp of the first response-body write, or the
// zero time if nothing has been written yet.
func (w *responseBodyCopier) StartTime() time.Time {
	return w.firstWrite
}

// responseTimingInfo is an immutable snapshot of the timestamps recorded by
// responseBodyCopier, used to derive durations after the response completes.
type responseTimingInfo struct {
	requestStart time.Time // when the copier was created (request began)
	firstWrite   time.Time // first body write; zero if no writes occurred
	lastWrite    time.Time // most recent body write; zero if no writes occurred
}

// Timings copies the recorded write timestamps into a responseTimingInfo
// value so durations can be computed without holding the copier itself.
func (w *responseBodyCopier) Timings() responseTimingInfo {
	snapshot := responseTimingInfo{
		requestStart: w.requestStart,
		firstWrite:   w.firstWrite,
		lastWrite:    w.lastWrite,
	}
	return snapshot
}

// totalDuration reports the elapsed time from the start of the request to
// the last body write. If no write has been recorded yet, it measures up to
// now; if the request start itself is unknown, it returns 0.
func (t responseTimingInfo) totalDuration() time.Duration {
	if t.requestStart.IsZero() {
		return 0
	}
	end := t.lastWrite
	if end.IsZero() {
		// no writes recorded — fall back to the current time
		end = time.Now()
	}
	return end.Sub(t.requestStart)
}

// generationDuration reports the span between the first and last body
// writes. It returns 0 unless both timestamps were recorded (i.e. at least
// one write happened), so callers must treat 0 as "unknown".
func (t responseTimingInfo) generationDuration() time.Duration {
	if t.firstWrite.IsZero() {
		return 0
	}
	if t.lastWrite.IsZero() {
		return 0
	}
	return t.lastWrite.Sub(t.firstWrite)
}

// sensitiveHeaders lists headers that should be redacted in captures
Expand Down
157 changes: 157 additions & 0 deletions proxy/metrics_monitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,74 @@ func TestMetricsMonitor_WrapHandler(t *testing.T) {
assert.Equal(t, 50, metrics[0].OutputTokens)
})

t.Run("successful responses request with input and output usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

responseBody := `{
"object": "response",
"output": [{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Hello"}]
}],
"usage": {
"input_tokens": 120,
"output_tokens": 45
}
}`

nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
time.Sleep(10 * time.Millisecond)
return nil
}

req := httptest.NewRequest("POST", "/v1/responses", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)

err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)

metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 120, metrics[0].InputTokens)
assert.Equal(t, 45, metrics[0].OutputTokens)
assert.Equal(t, -1.0, metrics[0].PromptPerSecond)
assert.Equal(t, -1.0, metrics[0].TokensPerSecond)
})

t.Run("chunked non-streaming responses request does not estimate generation speed", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"object":"response","usage":{"input_tokens":120,`))
time.Sleep(15 * time.Millisecond)
_, _ = w.Write([]byte(`"output_tokens":45,"total_tokens":165},`))
time.Sleep(15 * time.Millisecond)
_, _ = w.Write([]byte(`"output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"Hello"}]}]}`))
return nil
}

req := httptest.NewRequest("POST", "/v1/responses", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)

err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)

metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 120, metrics[0].InputTokens)
assert.Equal(t, 45, metrics[0].OutputTokens)
assert.Equal(t, -1.0, metrics[0].TokensPerSecond)
})

t.Run("successful request with timings data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

Expand Down Expand Up @@ -679,6 +747,95 @@ data: [DONE]
assert.Equal(t, 50, metrics[0].OutputTokens)
})

t.Run("finds metrics in OpenAI Responses completion event", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

responseBody := `data: {"event":"response.created","data":{"type":"response.created","response":{"id":"resp_123","object":"response","status":"in_progress"}}}

data: {"event":"response.output_text.delta","data":{"type":"response.output_text.delta","item_id":"msg_123","delta":"Hello"}}

data: {"event":"response.completed","data":{"type":"response.completed","response":{"id":"resp_123","object":"response","status":"completed","usage":{"input_tokens":80,"output_tokens":25,"total_tokens":105}}}}

data: [DONE]

`

nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}

req := httptest.NewRequest("POST", "/v1/responses", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)

err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)

metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 80, metrics[0].InputTokens)
assert.Equal(t, 25, metrics[0].OutputTokens)
})

t.Run("estimates prompt and generation speed for streamed Responses events", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
time.Sleep(15 * time.Millisecond)
_, _ = w.Write([]byte("data: {\"event\":\"response.output_text.delta\",\"data\":{\"type\":\"response.output_text.delta\",\"item_id\":\"msg_123\",\"delta\":\"Hello\"}}\n\n"))
time.Sleep(20 * time.Millisecond)
_, _ = w.Write([]byte("data: {\"event\":\"response.completed\",\"data\":{\"type\":\"response.completed\",\"response\":{\"id\":\"resp_123\",\"object\":\"response\",\"status\":\"completed\",\"usage\":{\"input_tokens\":80,\"output_tokens\":25,\"total_tokens\":105}}}}\n\n"))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
return nil
}

req := httptest.NewRequest("POST", "/v1/responses", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)

err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)

metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 80, metrics[0].InputTokens)
assert.Equal(t, 25, metrics[0].OutputTokens)
assert.Equal(t, -1.0, metrics[0].PromptPerSecond)
assert.Greater(t, metrics[0].TokensPerSecond, 0.0)
})

t.Run("single write fallback leaves generation speed unknown", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
time.Sleep(15 * time.Millisecond)
_, _ = w.Write([]byte("data: {\"event\":\"response.completed\",\"data\":{\"type\":\"response.completed\",\"response\":{\"id\":\"resp_123\",\"object\":\"response\",\"status\":\"completed\",\"usage\":{\"input_tokens\":8,\"output_tokens\":1,\"total_tokens\":9}}}}\n\n"))
_, _ = w.Write([]byte("data: [DONE]\n\n"))
return nil
}

req := httptest.NewRequest("POST", "/v1/responses", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)

err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)

metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 8, metrics[0].InputTokens)
assert.Equal(t, 1, metrics[0].OutputTokens)
assert.Equal(t, -1.0, metrics[0].PromptPerSecond)
assert.Equal(t, -1.0, metrics[0].TokensPerSecond)
})

t.Run("handles streaming with no valid JSON records minimal metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10, 0)

Expand Down
Loading
Loading