Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions go/Tiltfile
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ if start_api:
resource_deps=api_deps,
labels=['unkey'],
auto_init=True,
trigger_mode=TRIGGER_MODE_AUTO if debug_mode else TRIGGER_MODE_MANUAL
trigger_mode=TRIGGER_MODE_MANUAL if debug_mode else TRIGGER_MODE_AUTO
)

# Gateway service (1 replica)
Expand Down Expand Up @@ -204,7 +204,7 @@ if start_gw:
resource_deps=gw_deps,
labels=['unkey'],
auto_init=True,
trigger_mode=TRIGGER_MODE_AUTO if debug_mode else TRIGGER_MODE_MANUAL
trigger_mode=TRIGGER_MODE_MANUAL if debug_mode else TRIGGER_MODE_AUTO
)

# Ctrl service (1 replica)
Expand Down Expand Up @@ -238,7 +238,7 @@ if start_ctrl:
resource_deps=ctrl_deps,
labels=['unkey'],
auto_init=True,
trigger_mode=TRIGGER_MODE_AUTO if debug_mode else TRIGGER_MODE_MANUAL
trigger_mode=TRIGGER_MODE_MANUAL if debug_mode else TRIGGER_MODE_AUTO
)

# Metald service (1 replica)
Expand Down
2 changes: 1 addition & 1 deletion go/apps/api/routes/chproxy_metrics/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {

// Buffer all events to ClickHouse
for _, event := range events {
h.ClickHouse.BufferApiRequest(event)
h.ClickHouse.BufferRequest(event)
}

return s.JSON(http.StatusOK, map[string]string{"status": "OK"})
Expand Down
5 changes: 0 additions & 5 deletions go/apps/api/routes/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,12 @@ import (
"github.com/unkeyed/unkey/go/internal/services/ratelimit"
"github.com/unkeyed/unkey/go/internal/services/usagelimiter"
"github.com/unkeyed/unkey/go/pkg/clickhouse"
"github.com/unkeyed/unkey/go/pkg/clickhouse/schema"
"github.com/unkeyed/unkey/go/pkg/db"
"github.com/unkeyed/unkey/go/pkg/otel/logging"
"github.com/unkeyed/unkey/go/pkg/vault"
"github.com/unkeyed/unkey/go/pkg/zen/validation"
)

type EventBuffer interface {
BufferApiRequest(schema.ApiRequestV1)
}

type Services struct {
Logger logging.Logger
Database db.Database
Expand Down
15 changes: 12 additions & 3 deletions go/apps/gw/router/gateway_proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (h *Handler) Handle(ctx context.Context, sess *server.Session) error {
// Strip port from hostname for database lookup (Host header may include port)
hostname := routing.ExtractHostname(req)

config, err := h.RoutingService.GetConfig(ctx, hostname)
configWithWorkspace, err := h.RoutingService.GetConfig(ctx, hostname)
if err != nil {
return fault.Wrap(err,
fault.Code(codes.Gateway.Routing.ConfigNotFound.URN()),
Expand All @@ -44,6 +44,10 @@ func (h *Handler) Handle(ctx context.Context, sess *server.Session) error {
)
}

// Set workspace ID in session
sess.WorkspaceID = configWithWorkspace.WorkspaceID
config := configWithWorkspace.Config

// Handle request validation if configured
if h.Validator != nil {
err = h.Validator.Validate(ctx, sess, config)
Expand All @@ -68,8 +72,13 @@ func (h *Handler) Handle(ctx context.Context, sess *server.Session) error {
)
}

// Forward the request using the proxy service
err = h.Proxy.Forward(ctx, *targetURL, sess.ResponseWriter(), req)
// Forward the request using the proxy service with response capture
captureWriter, captureFunc := sess.CaptureResponseWriter()
err = h.Proxy.Forward(ctx, *targetURL, captureWriter, req)

// Capture the response data back to session after forwarding
captureFunc()

if err != nil {
return fault.Wrap(err,
fault.Code(codes.Gateway.Proxy.ProxyForwardFailed.URN()),
Expand Down
3 changes: 3 additions & 0 deletions go/apps/gw/server/middleware_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func WithErrorHandling(logger logging.Logger) Middleware {
return nil
}

// Store the original error for metrics logging
s.SetError(err)

// Get the error URN from the error
urn, ok := fault.GetCode(err)
if !ok {
Expand Down
70 changes: 34 additions & 36 deletions go/apps/gw/server/middleware_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import (
"strings"

"github.com/unkeyed/unkey/go/pkg/clickhouse/schema"
"github.com/unkeyed/unkey/go/pkg/fault"
)

// EventBuffer defines the interface for buffering events to be sent to ClickHouse.
type EventBuffer interface {
BufferApiRequest(schema.ApiRequestV1)
BufferApiRequest(schema.ApiRequestV2)
}

// WithMetrics returns middleware that collects metrics about each request,
Expand Down Expand Up @@ -52,49 +51,48 @@ func WithMetrics(eventBuffer EventBuffer, region string) Middleware {

// Buffer to ClickHouse if enabled
// We don't need this ATM
// if eventBuffer != nil && s.r.Header.Get("X-Unkey-Metrics") != "disabled" {
// // Extract IP address from headers
// ips := strings.Split(s.r.Header.Get("X-Forwarded-For"), ",")
// ipAddress := ""
// if len(ips) > 0 {
// ipAddress = strings.TrimSpace(ips[0])
// }
// if ipAddress == "" {
// ipAddress = s.Location()
// }
if eventBuffer != nil {
// Extract IP address from headers
ips := strings.Split(s.r.Header.Get("X-Forwarded-For"), ",")
ipAddress := ""
if len(ips) > 0 {
ipAddress = strings.TrimSpace(ips[0])
}
if ipAddress == "" {
ipAddress = s.Location()
}

// eventBuffer.BufferApiRequest(schema.ApiRequestV1{
// WorkspaceID: s.WorkspaceID,
// RequestID: s.RequestID(),
// Time: s.startTime.UnixMilli(),
// Host: s.r.Host,
// Method: s.r.Method,
// Path: s.r.URL.Path,
// RequestHeaders: requestHeaders,
// RequestBody: string(s.requestBody),
// ResponseStatus: s.responseStatus,
// ResponseHeaders: responseHeaders,
// ResponseBody: string(s.responseBody),
// Error: getErrorMessage(nextErr),
// ServiceLatency: s.Latency().Milliseconds(),
// UserAgent: s.UserAgent(),
// IpAddress: ipAddress,
// Country: "",
// City: "",
// Colo: "",
// Continent: "",
// })
// }
eventBuffer.BufferApiRequest(schema.ApiRequestV2{
WorkspaceID: s.WorkspaceID,
RequestID: s.RequestID(),
Time: s.startTime.UnixMilli(),
Host: s.r.Host,
Method: s.r.Method,
Path: s.r.URL.Path,
RequestHeaders: requestHeaders,
RequestBody: string(s.requestBody),
ResponseStatus: int32(s.responseStatus),
ResponseHeaders: responseHeaders,
ResponseBody: string(s.responseBody),
Error: getErrorMessage(s.error),
ServiceLatency: s.Latency().Milliseconds(),
UserAgent: s.UserAgent(),
IpAddress: ipAddress,
Region: region,
})
}

return nextErr
}
}
}

// getErrorMessage extracts the user-facing error message if available.
// getErrorMessage extracts the internal error message for logging.
func getErrorMessage(err error) string {
if err == nil {
return ""
}
return fault.UserFacingMessage(err)

// Fallback for non-fault errors
return err.Error()
}
19 changes: 17 additions & 2 deletions go/apps/gw/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,24 @@ func (s *Server) WrapHandler(handler HandleFunc, middlewares []Middleware) http.
s.returnSession(sess)
}()

sess.init(w, r)
err := sess.init(w, r)
if err != nil {
// Apply default middleware chain for session initialization errors
handleFn := func(ctx context.Context, session *Session) error {
return err // Return the session init error
}

// Apply the same middleware chain
var wrappedHandler HandleFunc = handleFn
for i := len(middlewares) - 1; i >= 0; i-- {
wrappedHandler = middlewares[i](wrappedHandler)
}

_ = wrappedHandler(r.Context(), sess)
return
}

err := handle(r.Context(), sess)
err = handle(r.Context(), sess)
if err != nil {
// Error should have been handled by error middleware
// If we get here, something went wrong
Expand Down
85 changes: 84 additions & 1 deletion go/apps/gw/server/session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"bytes"
"encoding/json"
"io"
"log"
Expand Down Expand Up @@ -28,15 +29,43 @@ type Session struct {
requestBody []byte
responseStatus int
responseBody []byte
error error
}

// init initializes the session with a new request and response writer.
func (s *Session) init(w http.ResponseWriter, r *http.Request) {
func (s *Session) init(w http.ResponseWriter, r *http.Request) error {
s.requestID = uid.New(uid.RequestPrefix)
s.startTime = time.Now()
s.w = w
s.r = r
s.WorkspaceID = ""

// Read and cache the request body so metrics middleware can access it even on early errors.
// We need to replace r.Body with a fresh reader afterwards so other middleware
// can still read the body if necessary.
var err error
s.requestBody, err = io.ReadAll(s.r.Body)
closeErr := s.r.Body.Close()

// Handle read errors
if err != nil {
return fault.Wrap(err,
fault.Internal("unable to read request body"),
fault.Public("The request body could not be read."),
)
}

// Handle close error
if closeErr != nil {
return fault.Wrap(closeErr,
fault.Internal("failed to close request body"),
fault.Public("An error occurred processing the request."),
)
}

// Replace body with a fresh reader for subsequent middleware
s.r.Body = io.NopCloser(bytes.NewReader(s.requestBody))
return nil
}

// RequestID returns the unique request ID for this session.
Expand All @@ -59,6 +88,30 @@ func (s *Session) ResponseWriter() http.ResponseWriter {
return s.w
}

// CaptureResponseWriter returns a ResponseWriter that captures the response body.
// It returns the wrapper and a function to retrieve the captured data.
func (s *Session) CaptureResponseWriter() (http.ResponseWriter, func()) {
wrapper := &captureResponseWriter{
ResponseWriter: s.w,
statusCode: http.StatusOK, // Default to 200 if not set
}

// Return a function to store captured data back in session
capture := func() {
s.responseStatus = wrapper.statusCode
s.responseBody = wrapper.body
}

return wrapper, capture
}

// SetError stores the error for logging purposes.
func (s *Session) SetError(err error) {
if s.error == nil {
s.error = err
}
}

// UserAgent returns the User-Agent header from the request.
func (s *Session) UserAgent() string {
return s.r.UserAgent()
Expand Down Expand Up @@ -154,6 +207,7 @@ func (s *Session) reset() {
s.requestBody = nil
s.responseStatus = 0
s.responseBody = nil
s.error = nil
}

// wrapResponseWriter wraps http.ResponseWriter to capture the status code.
Expand All @@ -180,3 +234,32 @@ func (w *wrapResponseWriter) Write(b []byte) (int, error) {

return w.ResponseWriter.Write(b)
}

// captureResponseWriter wraps http.ResponseWriter to capture the status code and response body.
type captureResponseWriter struct {
http.ResponseWriter
statusCode int
body []byte
written bool
}

func (w *captureResponseWriter) WriteHeader(code int) {
if w.written {
return // Already written, don't write again
}

w.statusCode = code
w.written = true
w.ResponseWriter.WriteHeader(code)
}

func (w *captureResponseWriter) Write(b []byte) (int, error) {
if !w.written {
w.WriteHeader(http.StatusOK)
}

// Capture the body
w.body = append(w.body, b...)

return w.ResponseWriter.Write(b)
}
6 changes: 3 additions & 3 deletions go/apps/gw/services/caches/caches.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"time"

validator "github.com/pb33f/libopenapi-validator"
partitionv1 "github.com/unkeyed/unkey/go/gen/proto/partition/v1"
"github.com/unkeyed/unkey/go/apps/gw/services/routing"
"github.com/unkeyed/unkey/go/pkg/cache"
"github.com/unkeyed/unkey/go/pkg/cache/middleware"
"github.com/unkeyed/unkey/go/pkg/clock"
Expand All @@ -19,7 +19,7 @@ import (
// Each field represents a specialized cache for a specific data entity.
type Caches struct {
// HostName -> Config
GatewayConfig cache.Cache[string, *partitionv1.GatewayConfig]
GatewayConfig cache.Cache[string, routing.ConfigWithWorkspace]

// DeploymentID -> OpenAPI Spec Validator
OpenAPISpec cache.Cache[string, validator.Validator]
Expand Down Expand Up @@ -72,7 +72,7 @@ type Config struct {
// // Use the caches
// key, err := caches.KeyByHash.Get(ctx, "some-hash")
func New(config Config) (Caches, error) {
gatewayConfig, err := cache.New(cache.Config[string, *partitionv1.GatewayConfig]{
gatewayConfig, err := cache.New(cache.Config[string, routing.ConfigWithWorkspace]{
Fresh: time.Second * 5,
Stale: time.Second * 30,
Logger: config.Logger,
Expand Down
6 changes: 3 additions & 3 deletions go/apps/gw/services/routing/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (

// Service handles gateway configuration lookup and VM selection.
type Service interface {
// GetTargetByHost finds gateway configuration based on the request host
GetConfig(ctx context.Context, host string) (*partitionv1.GatewayConfig, error)
// GetConfig finds gateway configuration and workspace ID based on the request host
GetConfig(ctx context.Context, host string) (*ConfigWithWorkspace, error)

// SelectVM picks an available VM from the gateway's VM list
SelectVM(ctx context.Context, config *partitionv1.GatewayConfig) (*url.URL, error)
Expand All @@ -27,6 +27,6 @@ type Config struct {
Logger logging.Logger
Clock clock.Clock

GatewayConfigCache cache.Cache[string, *partitionv1.GatewayConfig]
GatewayConfigCache cache.Cache[string, ConfigWithWorkspace]
VMCache cache.Cache[string, pdb.Vm]
}
Loading