diff --git a/Dockerfile.sidecar b/Dockerfile.sidecar index 9c9e747e20..989fd13bed 100644 --- a/Dockerfile.sidecar +++ b/Dockerfile.sidecar @@ -17,6 +17,7 @@ RUN go mod download COPY cmd/pd-sidecar/main.go cmd/cmd.go COPY pkg/sidecar pkg/sidecar COPY pkg/common pkg/common +COPY pkg/telemetry pkg/telemetry # Build # the GOARCH has not a default value to allow the binary be built according to the host where the command diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 1952fcf30a..0638f6c4c6 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -32,15 +32,40 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) func main() { + os.Exit(run()) +} + +func run() int { + ctx := ctrl.SetupSignalHandler() + + // Initialize tracing before creating any spans + shutdownTracing, err := telemetry.InitTracing(ctx) + if err != nil { + // Log error but don't fail - tracing is optional + ctrl.Log.Error(err, "Failed to initialize tracing") + } + if shutdownTracing != nil { + defer func() { + if err := shutdownTracing(ctx); err != nil { + ctrl.Log.Error(err, "Failed to shutdown tracing") + } + }() + } + // Register llm-d-inference-scheduler plugins plugins.RegisterAllPlugins() + // Note: GIE built-in plugins are automatically registered by the runner + // when it processes configuration in runner.parsePluginsConfiguration() + if err := runner.NewRunner(). WithCustomCollectors(metrics.GetCollectors()...). 
- Run(ctrl.SetupSignalHandler()); err != nil { - os.Exit(1) + Run(ctx); err != nil { + return 1 } + return 0 } diff --git a/cmd/pd-sidecar/main.go b/cmd/pd-sidecar/main.go index 8e9e6533ca..45927102e6 100644 --- a/cmd/pd-sidecar/main.go +++ b/cmd/pd-sidecar/main.go @@ -29,6 +29,7 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/pkg/sidecar/proxy" "github.com/llm-d/llm-d-inference-scheduler/pkg/sidecar/version" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) var ( @@ -70,6 +71,20 @@ func main() { ctx := ctrl.SetupSignalHandler() log.IntoContext(ctx, logger) + // Initialize tracing before creating any spans + shutdownTracing, err := telemetry.InitTracing(ctx) + if err != nil { + // Log error but don't fail - tracing is optional + logger.Error(err, "Failed to initialize tracing") + } + if shutdownTracing != nil { + defer func() { + if err := shutdownTracing(ctx); err != nil { + logger.Error(err, "Failed to shutdown tracing") + } + }() + } + logger.Info("Proxy starting", "Built on", version.BuildRef, "From Git SHA", version.CommitSHA) // Validate connector diff --git a/go.mod b/go.mod index e29565d848..f0818e2013 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,11 @@ require ( github.com/openai/openai-go v1.12.0 github.com/prometheus/client_golang v1.23.2 github.com/stretchr/testify v1.11.1 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 + go.opentelemetry.io/otel v1.39.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 + go.opentelemetry.io/otel/sdk v1.39.0 + go.opentelemetry.io/otel/trace v1.39.0 golang.org/x/sync v0.19.0 google.golang.org/grpc v1.79.1 k8s.io/api v0.34.4 @@ -103,14 +108,9 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xlab/treeprint v1.2.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect - go.opentelemetry.io/otel v1.39.0 // indirect 
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 // indirect go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.39.0 // indirect go.opentelemetry.io/otel/metric v1.39.0 // indirect - go.opentelemetry.io/otel/sdk v1.39.0 // indirect - go.opentelemetry.io/otel/trace v1.39.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/pkg/common/common.go b/pkg/common/common.go index a6c50fe7cc..2fb6f45b42 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -4,6 +4,8 @@ //revive:disable:var-naming package common +import "net/url" + const ( // PrefillPodHeader is the header name used to indicate Prefill worker PrefillPodHeader = "x-prefiller-host-port" @@ -11,3 +13,13 @@ const ( // DataParallelPodHeader is the header name used to indicate the worker for Data Parallel DataParallelPodHeader = "x-data-parallel-host-port" ) + +// StripScheme removes the scheme from an endpoint URL, returning host:port. +// This is useful for gRPC clients that expect host:port format only. 
+func StripScheme(endpoint string) string { + u, err := url.Parse(endpoint) + if err != nil || u.Host == "" { + return endpoint // not a valid URL, return as-is + } + return u.Host +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go new file mode 100644 index 0000000000..dc20b36071 --- /dev/null +++ b/pkg/common/common_test.go @@ -0,0 +1,90 @@ +// Package common contains items common to both the +// EPP/Inference-Scheduler and the Routing Sidecar +// +//revive:disable:var-naming +package common + +import "testing" + +func TestStripScheme(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "http scheme", + input: "http://localhost:4317", + expected: "localhost:4317", + }, + { + name: "https scheme", + input: "https://localhost:4317", + expected: "localhost:4317", + }, + { + name: "no scheme", + input: "localhost:4317", + expected: "localhost:4317", + }, + { + name: "host only", + input: "localhost", + expected: "localhost", + }, + { + name: "http with domain", + input: "http://otel-collector.monitoring.svc.cluster.local:4317", + expected: "otel-collector.monitoring.svc.cluster.local:4317", + }, + { + name: "https with domain", + input: "https://otel-collector.monitoring.svc.cluster.local:4317", + expected: "otel-collector.monitoring.svc.cluster.local:4317", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "ip address with http", + input: "http://10.0.0.1:4317", + expected: "10.0.0.1:4317", + }, + { + name: "ip address with https", + input: "https://10.0.0.1:4317", + expected: "10.0.0.1:4317", + }, + { + name: "ip address without scheme", + input: "10.0.0.1:4317", + expected: "10.0.0.1:4317", + }, + { + name: "schemeless with double slash", + input: "//192.168.1.1:80", + expected: "192.168.1.1:80", + }, + { + name: "uppercase scheme", + input: "HTTP://localhost:4317", + expected: "localhost:4317", + }, + { + name: "port only", + input: ":9090", + expected: ":9090", + }, 
+ } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := StripScheme(tt.input) + if result != tt.expected { + t.Errorf("StripScheme(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} diff --git a/pkg/plugins/pre-request/pd_prerequest.go b/pkg/plugins/pre-request/pd_prerequest.go index c77fc700f8..9bfda5a4fb 100644 --- a/pkg/plugins/pre-request/pd_prerequest.go +++ b/pkg/plugins/pre-request/pd_prerequest.go @@ -7,11 +7,14 @@ import ( "fmt" "net" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) const ( @@ -67,17 +70,39 @@ func (p *PrefillHeaderHandler) WithName(name string) *PrefillHeaderHandler { } // PreRequest wires prefill SchedulerProfile result into a header to indicate prefill worker -func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) { +func (p *PrefillHeaderHandler) PreRequest(ctx context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) { + tracer := telemetry.Tracer() + _, span := tracer.Start(ctx, "llm_d.epp.prerequest.pd_disaggregation", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + + if request != nil && request.TargetModel != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", request.TargetModel)) + } + if request != nil && request.RequestId != "" { + span.SetAttributes(attribute.String("gen_ai.request.id", request.RequestId)) + } if _, found := request.Headers[common.PrefillPodHeader]; found { request.Headers[common.PrefillPodHeader] = "" // clear 
header, if already set } prefillProfileRunResult, exists := schedulingResult.ProfileResults[p.prefillProfile] if !exists { + span.SetAttributes( + attribute.Bool("llm_d.epp.pd.disaggregation_used", false), + attribute.String("llm_d.epp.pd.reason", "no_prefill_profile_result"), + ) return // prefill profile failed to run or we chose not to run it, no-op in this case } targetPod := prefillProfileRunResult.TargetEndpoints[0].GetMetadata() prefillHostPort := net.JoinHostPort(targetPod.Address, targetPod.Port) request.Headers[common.PrefillPodHeader] = prefillHostPort // in the form of + + span.SetAttributes( + attribute.Bool("llm_d.epp.pd.disaggregation_used", true), + attribute.String("llm_d.epp.pd.prefill_pod_address", targetPod.Address), + attribute.String("llm_d.epp.pd.prefill_pod_port", targetPod.Port), + ) } diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go index 3f5a4b9320..d0e1038f53 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++ b/pkg/plugins/profile/pd_profile_handler.go @@ -9,6 +9,9 @@ import ( "net" "strconv" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "sigs.k8s.io/controller-runtime/pkg/log" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" @@ -17,6 +20,7 @@ import ( "github.com/llm-d/llm-d-inference-scheduler/pkg/common" "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) const ( @@ -143,8 +147,35 @@ func (h *PdProfileHandler) WithName(name string) *PdProfileHandler { // previously executed cycles along with their results. 
func (h *PdProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile { + // Start tracing span for profile picking operation + tracer := telemetry.Tracer() + ctx, span := tracer.Start(ctx, "llm_d.epp.pd.profile_handler.pick", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + + // Set initial attributes + span.SetAttributes( + attribute.Int("llm_d.profile_handler.total_profiles", len(profiles)), + attribute.Int("llm_d.profile_handler.executed_profiles", len(profileResults)), + ) + + // Set optional request attributes if request is not nil + if request != nil { + if request.TargetModel != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", request.TargetModel)) + } + if request.RequestId != "" { + span.SetAttributes(attribute.String("gen_ai.request.id", request.RequestId)) + } + } + if _, executed := profileResults[h.decodeProfile]; !executed { // if decode profile was not executed yet, first let the scheduler run the decode profile + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "run_decode"), + attribute.String("llm_d.profile_handler.selected_profile", h.decodeProfile), + ) return map[string]scheduling.SchedulerProfile{ h.decodeProfile: profiles[h.decodeProfile], } @@ -154,24 +185,38 @@ func (h *PdProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, r // when a profile run fails its result value is nil. we need to check decode result before continuing to prefill // check if all configured profiles have been executed, or if decode failed, no need to run more profiles. 
if len(profiles) == len(profileResults) || profileResults[h.decodeProfile] == nil { + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "complete"), + attribute.Bool("llm_d.profile_handler.decode_failed", profileResults[h.decodeProfile] == nil), + ) return map[string]scheduling.SchedulerProfile{} } inputTokens, err := getUserInputLenInTokens(request) if err != nil { log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input") + span.SetStatus(codes.Error, err.Error()) return nil } + span.SetAttributes(attribute.Int("llm_d.profile_handler.input_tokens", inputTokens)) + if h.decider != nil && h.decider.disaggregate(ctx, inputTokens, profileResults[h.decodeProfile].TargetEndpoints[0]) { metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode) // run the prefill profile + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "prefill_decode"), + attribute.String("llm_d.profile_handler.selected_profile", h.prefillProfile), + ) return map[string]scheduling.SchedulerProfile{ h.prefillProfile: profiles[h.prefillProfile], } } metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly) + span.SetAttributes( + attribute.String("llm_d.profile_handler.decision", "decode_only"), + ) return map[string]scheduling.SchedulerProfile{} // do not run prefill } diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index f37f24f5e7..fb539429a5 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -12,11 +12,16 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" "github.com/llm-d/llm-d-kv-cache/pkg/tokenization/types" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "sigs.k8s.io/controller-runtime/pkg/log" logutil 
"sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" + + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" ) const ( @@ -177,9 +182,22 @@ func (s *PrecisePrefixCacheScorer) Category() scheduling.ScorerCategory { // Score scores the provided endpoint based on the KVCache index state. // The returned scores are normalized to a range of 0-1. func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + // Start tracing span for scoring operation + tracer := telemetry.Tracer() + ctx, span := tracer.Start(ctx, "llm_d.epp.scorer.prefix_cache", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + logger := log.FromContext(ctx).WithName(s.typedName.String()) debugLogger := logger.V(logutil.DEBUG) + // Set initial attributes + span.SetAttributes( + attribute.Int("llm_d.scorer.candidate_endpoints", len(endpoints)), + ) + + // Handle pod discovery and subscriber management if s.kvEventsConfig.DiscoverPods { // update subscribers here temporarily for _, endpoint := range endpoints { @@ -200,18 +218,34 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *schedu } } + // Early return if request is nil if request == nil { debugLogger.Info("Request is nil, skipping scoring") + span.SetAttributes(attribute.String("llm_d.scorer.result", "skipped_nil_request")) return nil } + // Set optional request attributes + if request.TargetModel != "" { + span.SetAttributes(attribute.String("gen_ai.request.model", request.TargetModel)) + } + if request.RequestId != "" { + span.SetAttributes(attribute.String("gen_ai.request.id", 
request.RequestId)) + } + scores, err := s.getScores(ctx, request) if err != nil { logger.Error(err, "Failed to get endpoint scores") + span.SetStatus(codes.Error, err.Error()) return nil } debugLogger.Info("Got endpoint scores", "scores", scores) + // Track scoring statistics + span.SetAttributes( + attribute.Int("llm_d.scorer.scores_computed", len(scores)), + ) + endpointToKey := func(endpoint scheduling.Endpoint) (string, bool) { metadata := endpoint.GetMetadata() if metadata == nil { @@ -221,6 +255,7 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *schedu return metadata.Address, true } + // Write prefix cache state to cycle state state := &prefix.SchedulingContextState{ PrefixHashes: []prefix.BlockHash{}, PrefixCacheServers: map[prefix.ServerID]int{}, @@ -234,7 +269,28 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *schedu } cycleState.Write(plugin.StateKey(s.typedName.String()), state) - return indexedScoresToNormalizedScoredPods(endpoints, endpointToKey, scores) + normalizedScores := indexedScoresToNormalizedScoredPods(endpoints, endpointToKey, scores) + + // Calculate score distribution for observability + if len(normalizedScores) > 0 { + maxScore := 0.0 + totalScore := 0.0 + for _, score := range normalizedScores { + if score > maxScore { + maxScore = score + } + totalScore += score + } + avgScore := totalScore / float64(len(normalizedScores)) + + span.SetAttributes( + attribute.Float64("llm_d.scorer.score.max", maxScore), + attribute.Float64("llm_d.scorer.score.avg", avgScore), + attribute.Int("llm_d.scorer.endpoints_scored", len(normalizedScores)), + ) + } + + return normalizedScores } // getScores retrieves the endpoint scores from the KV-cache indexer diff --git a/pkg/sidecar/proxy/chat_completions.go b/pkg/sidecar/proxy/chat_completions.go index 5ab731a6e4..1c367b9d28 100644 --- a/pkg/sidecar/proxy/chat_completions.go +++ b/pkg/sidecar/proxy/chat_completions.go @@ -17,12 +17,23 @@ limitations 
under the License. package proxy import ( + "context" "net/http" "strings" + "time" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const requestStartTimeKey contextKey = "request_start_time" + var ( // ChatCompletionsPath is the OpenAI chat completions path ChatCompletionsPath = "/v1/chat/completions" @@ -32,6 +43,27 @@ var ( ) func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) { + requestStart := time.Now() + tracer := telemetry.Tracer() + ctx, span := tracer.Start(r.Context(), "llm_d.pd_proxy.request", + trace.WithSpanKind(trace.SpanKindServer), + ) + defer span.End() + + // Update request context with span and start time + ctx = context.WithValue(ctx, requestStartTimeKey, requestStart) + r = r.WithContext(ctx) + + // Set span attributes with safe defaults for nil values + requestPath := "" + if r.URL != nil { + requestPath = r.URL.Path + } + span.SetAttributes( + attribute.String("llm_d.pd_proxy.connector", s.config.Connector), + attribute.String("llm_d.pd_proxy.request_path", requestPath), + ) + var prefillHostPorts []string prefillHostPorts = r.Header.Values(common.PrefillPodHeader) @@ -56,6 +88,10 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) if len(prefillHostPort) == 0 { s.logger.V(4).Info("skip disaggregated prefill") + span.SetAttributes( + attribute.Bool("llm_d.pd_proxy.disaggregation_used", false), + attribute.String("llm_d.pd_proxy.reason", "no_prefill_header"), + ) if !s.forwardDataParallel || !s.dataParallelHandler(w, r) { s.decoderProxy.ServeHTTP(w, r) @@ -63,6 +99,12 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) return } + span.SetAttributes( + 
attribute.Bool("llm_d.pd_proxy.disaggregation_used", true), + attribute.String("llm_d.pd_proxy.prefill_target", prefillHostPort), + attribute.Int("llm_d.pd_proxy.prefill_candidates", numHosts), + ) + // SSRF Protection: Check if the prefill target is allowed if !s.allowlistValidator.IsAllowed(prefillHostPort) { s.logger.Error(nil, "SSRF protection: prefill target not in allowlist", @@ -70,6 +112,11 @@ func (s *Server) chatCompletionsHandler(w http.ResponseWriter, r *http.Request) "clientIP", r.RemoteAddr, "userAgent", r.Header.Get("User-Agent"), "requestPath", r.URL.Path) + span.SetAttributes( + attribute.String("llm_d.pd_proxy.error", "ssrf_protection_denied"), + attribute.String("llm_d.pd_proxy.denied_target", prefillHostPort), + ) + span.SetStatus(codes.Error, "SSRF protection: prefill target not in allowlist") http.Error(w, "Forbidden: prefill target not allowed by SSRF protection", http.StatusForbidden) return } diff --git a/pkg/sidecar/proxy/connector_nixlv2.go b/pkg/sidecar/proxy/connector_nixlv2.go index f8543a7119..c2b3001c89 100644 --- a/pkg/sidecar/proxy/connector_nixlv2.go +++ b/pkg/sidecar/proxy/connector_nixlv2.go @@ -21,8 +21,13 @@ import ( "io" "net/http" "strings" + "time" "github.com/google/uuid" + "github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" ) func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) { @@ -57,9 +62,20 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi uuidStr := uuid.String() // Prefill Stage + tracer := telemetry.Tracer() + ctx := r.Context() + + ctx, prefillSpan := tracer.Start(ctx, "llm_d.pd_proxy.prefill", + trace.WithSpanKind(trace.SpanKindInternal), + ) + prefillSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.request_id", uuidStr), + attribute.String("llm_d.pd_proxy.prefill_target", prefillPodHostPort), + 
attribute.String("llm_d.pd_proxy.connector", "nixlv2"), + ) + prefillStart := time.Now() // 1. Prepare prefill request - ctx := r.Context() preq := r.Clone(ctx) preq.Header.Add(requestHeaderRequestID, uuidStr) @@ -107,11 +123,20 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi pw := &bufferedResponseWriter{} prefillHandler.ServeHTTP(pw, preq) + prefillDuration := time.Since(prefillStart) + prefillSpan.SetAttributes( + attribute.Int("llm_d.pd_proxy.prefill.status_code", pw.statusCode), + attribute.Float64("llm_d.pd_proxy.prefill.duration_ms", float64(prefillDuration.Milliseconds())), + ) + if isHTTPError(pw.statusCode) { s.logger.Error(err, "request failed", "code", pw.statusCode) + prefillSpan.SetStatus(codes.Error, "prefill request failed") + prefillSpan.End() w.WriteHeader(pw.statusCode) return } + prefillSpan.End() // Process response - extract p/d fields var prefillerResponse map[string]any @@ -133,15 +158,31 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi // Decode Stage + ctx, decodeSpan := tracer.Start(ctx, "llm_d.pd_proxy.decode", + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer decodeSpan.End() + + decodeSpan.SetAttributes( + attribute.String("llm_d.pd_proxy.request_id", uuidStr), + attribute.String("llm_d.pd_proxy.connector", "nixlv2"), + ) + decodeStart := time.Now() + // 1. Prepare decode request dreq := r.Clone(ctx) dreq.Header.Add(requestHeaderRequestID, uuidStr) delete(completionRequest, requestFieldStream) + streamingEnabled := false if streamOk { completionRequest[requestFieldStream] = streamValue + if streamBool, ok := streamValue.(bool); ok { + streamingEnabled = streamBool + } } + decodeSpan.SetAttributes(attribute.Bool("llm_d.pd_proxy.decode.streaming", streamingEnabled)) if streamOptionsOk { completionRequest[requestFieldStreamOptions] = streamOptionsValue } @@ -168,8 +209,40 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi // 2. 
Forward to local decoder. s.logger.V(5).Info("sending request to decoder", "body", string(dbody)) - if !s.forwardDataParallel || !s.dataParallelHandler(w, dreq) { + dataParallelUsed := s.forwardDataParallel && s.dataParallelHandler(w, dreq) + decodeSpan.SetAttributes(attribute.Bool("llm_d.pd_proxy.decode.data_parallel", dataParallelUsed)) + + if !dataParallelUsed { s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host) + decodeSpan.SetAttributes(attribute.String("llm_d.pd_proxy.decode.target", s.decoderURL.Host)) s.decoderProxy.ServeHTTP(w, dreq) } + + decodeDuration := time.Since(decodeStart) + decodeSpan.SetAttributes(attribute.Float64("llm_d.pd_proxy.decode.duration_ms", float64(decodeDuration.Milliseconds()))) + + // Calculate end-to-end P/D timing metrics. + // True TTFT captures time from gateway request start to decode start, including + // gateway routing, scheduling, prefill, and coordination overhead that + // per-instance vLLM metrics miss. + if currentSpan := trace.SpanFromContext(ctx); currentSpan.SpanContext().IsValid() { + var totalDuration time.Duration + var trueTTFT time.Duration + if requestStartValue := ctx.Value(requestStartTimeKey); requestStartValue != nil { + if requestStart, ok := requestStartValue.(time.Time); ok { + totalDuration = time.Since(requestStart) + trueTTFT = decodeStart.Sub(requestStart) + } + } + + coordinatorOverhead := decodeStart.Sub(prefillStart.Add(prefillDuration)) + + currentSpan.SetAttributes( + attribute.Float64("llm_d.pd_proxy.total_duration_ms", float64(totalDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.true_ttft_ms", float64(trueTTFT.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.prefill_duration_ms", float64(prefillDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.decode_duration_ms", float64(decodeDuration.Milliseconds())), + attribute.Float64("llm_d.pd_proxy.coordinator_overhead_ms", float64(coordinatorOverhead.Milliseconds())), + ) + } } diff --git 
a/pkg/sidecar/proxy/connector_sglang.go b/pkg/sidecar/proxy/connector_sglang.go
index 1e82513804..ee11a94ab5 100644
--- a/pkg/sidecar/proxy/connector_sglang.go
+++ b/pkg/sidecar/proxy/connector_sglang.go
@@ -28,6 +28,11 @@ import (
 	"strconv"
 	"strings"
 	"time"
+
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/telemetry"
+	"go.opentelemetry.io/otel/attribute"
+	"go.opentelemetry.io/otel/codes"
+	"go.opentelemetry.io/otel/trace"
 )
 
 var (
@@ -77,6 +82,20 @@ func (s *Server) runSGLangProtocol(w http.ResponseWriter, r *http.Request, prefi
 }
 
 func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Request, body []byte, prefillHost string) {
+	tracer := telemetry.Tracer()
+	ctx := r.Context()
+
+	// Prefill Stage - async. ctx is deliberately not rebound here so decode is a sibling, not a child, of this span.
+	_, prefillSpan := tracer.Start(ctx, "llm_d.pd_proxy.prefill",
+		trace.WithSpanKind(trace.SpanKindInternal),
+	)
+	prefillSpan.SetAttributes(
+		attribute.String("llm_d.pd_proxy.prefill_target", prefillHost),
+		attribute.String("llm_d.pd_proxy.connector", "sglang"),
+		attribute.Bool("llm_d.pd_proxy.prefill.async", true),
+	)
+	prefillStart := time.Now()
+
 	// Create separate requests for prefill and decode
 	// Use context.WithoutCancel for prefillReq to prevent it from being aborted
 	// if the main HTTP handler (which serves decodeReq) finishes first.
@@ -85,6 +104,8 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req
 
 	prefillHandler, err := s.prefillerProxyHandler(prefillHost)
 	if err != nil {
+		prefillSpan.SetStatus(codes.Error, "failed to create prefill handler")
+		prefillSpan.End()
 		if err := errorBadGateway(err, w); err != nil {
 			s.logger.Error(err, "failed to send error response to client")
 		}
@@ -93,6 +114,7 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req
 
 	// Send prefill request asynchronously
 	go func() {
+		defer prefillSpan.End()
 		defer func() {
 			if rec := recover(); rec != nil && rec != http.ErrAbortHandler {
 				s.logger.Error(fmt.Errorf("panic: %v", rec), "panic in prefill request")
@@ -100,11 +122,59 @@ func (s *Server) sendSGLangConcurrentRequests(w http.ResponseWriter, r *http.Req
 		}()
 		pw := &bufferedResponseWriter{}
 		prefillHandler.ServeHTTP(pw, prefillReq)
+		prefillDuration := time.Since(prefillStart)
+		prefillSpan.SetAttributes(
+			attribute.Int("llm_d.pd_proxy.prefill.status_code", pw.statusCode),
+			attribute.Float64("llm_d.pd_proxy.prefill.duration_ms", float64(prefillDuration.Milliseconds())),
+		)
+		if pw.statusCode < 200 || pw.statusCode >= 300 {
+			prefillSpan.SetStatus(codes.Error, "prefill request failed")
+		}
 		s.logger.V(5).Info("prefill request completed", "status", pw.statusCode)
 	}()
 
+	// Decode Stage - sync. decodeCtx carries the decode span; ctx still holds the request-level span.
+	decodeCtx, decodeSpan := tracer.Start(ctx, "llm_d.pd_proxy.decode",
+		trace.WithSpanKind(trace.SpanKindInternal),
+	)
+	defer decodeSpan.End()
+
+	decodeSpan.SetAttributes(
+		attribute.String("llm_d.pd_proxy.connector", "sglang"),
+		attribute.Bool("llm_d.pd_proxy.decode.concurrent_with_prefill", true),
+	)
+	decodeStart := time.Now()
+
 	// Send decode request synchronously
+	decodeReq = decodeReq.WithContext(decodeCtx)
 	s.decoderProxy.ServeHTTP(w, decodeReq)
+
+	decodeDuration := time.Since(decodeStart)
+	decodeSpan.SetAttributes(
+		attribute.Float64("llm_d.pd_proxy.decode.duration_ms", float64(decodeDuration.Milliseconds())),
+		attribute.String("llm_d.pd_proxy.decode.target", s.decoderURL.Host),
+	)
+
+	// Calculate end-to-end P/D timing metrics for concurrent P/D and record them on the
+	// request-level span (ctx was never rebound to the prefill or decode span). True TTFT is
+	// request start to decode start; prefill duration is tracked in the async prefill span.
+	if currentSpan := trace.SpanFromContext(ctx); currentSpan.SpanContext().IsValid() {
+		var totalDuration time.Duration
+		var trueTTFT time.Duration
+		if requestStartValue := ctx.Value(requestStartTimeKey); requestStartValue != nil {
+			if requestStart, ok := requestStartValue.(time.Time); ok {
+				totalDuration = time.Since(requestStart)
+				trueTTFT = decodeStart.Sub(requestStart)
+			}
+		}
+
+		currentSpan.SetAttributes(
+			attribute.Float64("llm_d.pd_proxy.total_duration_ms", float64(totalDuration.Milliseconds())),
+			attribute.Float64("llm_d.pd_proxy.true_ttft_ms", float64(trueTTFT.Milliseconds())),
+			attribute.Float64("llm_d.pd_proxy.decode_duration_ms", float64(decodeDuration.Milliseconds())),
+			attribute.Bool("llm_d.pd_proxy.concurrent_pd", true),
+		)
+	}
 }
 
 func cloneWithJSONBody(ctx context.Context, r *http.Request, body []byte) *http.Request {
diff --git a/pkg/sidecar/proxy/proxy_helpers.go b/pkg/sidecar/proxy/proxy_helpers.go
index 30ba764943..8be0623636 100644
--- a/pkg/sidecar/proxy/proxy_helpers.go
+++ b/pkg/sidecar/proxy/proxy_helpers.go
@@ -10,6 +10,8 @@ import (
 	"net/url"
 	"syscall"
 	"time"
+
+	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
 )
 
 // startHTTP starts the HTTP reverse proxy.
@@ -27,8 +29,19 @@ func (s *Server) startHTTP(ctx context.Context, cert *tls.Certificate) error {
 	}
 	s.addr = ln.Addr()
 
+	// Wrap handler with OpenTelemetry middleware to extract trace context from incoming requests
+	handler := otelhttp.NewHandler(s.handler, "llm-d-pd-proxy",
+		otelhttp.WithSpanNameFormatter(func(_ string, r *http.Request) string {
+			path := ""
+			if r.URL != nil {
+				path = r.URL.Path
+			}
+			return "llm_d.pd_proxy."
+ r.Method + " " + path
+		}),
+	)
+
 	server := &http.Server{
-		Handler: s.handler,
+		Handler: handler,
 		// No ReadTimeout/WriteTimeout for LLM inference - can take hours for large contexts
 		IdleTimeout:       300 * time.Second, // 5 minutes for keep-alive connections
 		ReadHeaderTimeout: 30 * time.Second,  // Reasonable for headers only
diff --git a/pkg/telemetry/tracing.go b/pkg/telemetry/tracing.go
new file mode 100644
index 0000000000..3c12a42999
--- /dev/null
+++ b/pkg/telemetry/tracing.go
@@ -0,0 +1,158 @@
+/*
+Copyright 2025 The llm-d Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package telemetry provides OpenTelemetry tracing initialization and utilities
+// for distributed tracing across llm-d components.
+package telemetry
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"strconv"
+	"time"
+
+	"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
+	"go.opentelemetry.io/otel"
+	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
+	"go.opentelemetry.io/otel/propagation"
+	"go.opentelemetry.io/otel/sdk/resource"
+	sdktrace "go.opentelemetry.io/otel/sdk/trace"
+	semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
+	"go.opentelemetry.io/otel/trace"
+	"sigs.k8s.io/controller-runtime/pkg/log"
+)
+
+const (
+	defaultServiceName = "llm-d-inference-scheduler"
+
+	// instrumentationName identifies this instrumentation library in traces.
+	instrumentationName = "llm-d-inference-scheduler"
+
+	// shutdownTimeout bounds the final span flush when the context handed to
+	// the shutdown function has already been cancelled.
+	shutdownTimeout = 5 * time.Second
+)
+
+// InitTracing initializes OpenTelemetry tracing with an OTLP/gRPC exporter and
+// installs the tracer provider and the W3C trace-context/baggage propagators globally.
+//
+// Configuration is done via environment variables:
+//   - OTEL_SERVICE_NAME: Service name for tracing (default: llm-d-inference-scheduler)
+//   - OTEL_EXPORTER_OTLP_ENDPOINT: OTLP collector endpoint (default: localhost:4317);
+//     the value is passed through common.StripScheme before use
+//   - OTEL_TRACES_SAMPLER_ARG: Sampling ratio in [0, 1] (default: 0.1 for 10%)
+//
+// The sampler is always parent-based with a trace-ID-ratio root sampler;
+// OTEL_TRACES_SAMPLER is not read. The exporter always uses an insecure
+// (non-TLS) connection.
+//
+// On success it returns a shutdown function for the tracer provider. If the
+// context given to that function is already done, shutdown runs on a detached
+// context bounded by shutdownTimeout instead, so a deferred call made with a
+// cancelled signal context still gets a chance to flush buffered spans.
+func InitTracing(ctx context.Context) (func(context.Context) error, error) {
+	logger := log.FromContext(ctx)
+
+	// Get service name from environment, fallback to default
+	serviceName := os.Getenv("OTEL_SERVICE_NAME")
+	if serviceName == "" {
+		serviceName = defaultServiceName
+	}
+
+	// Get OTLP endpoint from environment
+	endpoint := os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT")
+	if endpoint == "" {
+		endpoint = "localhost:4317"
+	}
+
+	// Strip http:// or https:// prefix if present
+	// otlptracegrpc.WithEndpoint() expects host:port only
+	endpoint = common.StripScheme(endpoint)
+
+	logger.Info("Initializing OpenTelemetry tracing", "endpoint", endpoint, "service", serviceName)
+
+	// Create resource with service name. This happens before the exporter is
+	// created so that a failure here cannot leave an exporter behind with
+	// nothing to shut it down.
+	res, err := resource.New(ctx,
+		resource.WithAttributes(
+			semconv.ServiceName(serviceName),
+		),
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create resource: %w", err)
+	}
+
+	// Create OTLP trace exporter
+	exporter, err := otlptracegrpc.New(ctx,
+		otlptracegrpc.WithEndpoint(endpoint),
+		otlptracegrpc.WithInsecure(), // Use WithTLSCredentials() in production
+	)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create OTLP trace exporter: %w", err)
+	}
+
+	// Get sampling ratio from environment, fallback to default
+	samplingRatio := 0.1 // default 10% sampling
+	if arg := os.Getenv("OTEL_TRACES_SAMPLER_ARG"); arg != "" {
+		if ratio, err := strconv.ParseFloat(arg, 64); err == nil && ratio >= 0.0 && ratio <= 1.0 {
+			samplingRatio = ratio
+		} else {
+			logger.Info("Invalid OTEL_TRACES_SAMPLER_ARG, using default", "arg", arg, "default", samplingRatio)
+		}
+	}
+
+	logger.Info("Configuring trace sampling", "ratio", samplingRatio)
+
+	// Create trace provider with parent-based sampling
+	tp := sdktrace.NewTracerProvider(
+		sdktrace.WithBatcher(exporter),
+		sdktrace.WithResource(res),
+		sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(samplingRatio))),
+	)
+
+	// Set global trace provider
+	otel.SetTracerProvider(tp)
+
+	// Set W3C trace context propagator
+	otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
+		propagation.TraceContext{},
+		propagation.Baggage{},
+	))
+
+	logger.Info("OpenTelemetry tracing initialized successfully")
+
+	// Return a shutdown function that still flushes when the caller's context
+	// is already cancelled, which is the normal case for a deferred call made
+	// with the process signal context.
+	shutdown := func(ctx context.Context) error {
+		if ctx.Err() != nil {
+			var cancel context.CancelFunc
+			ctx, cancel = context.WithTimeout(context.WithoutCancel(ctx), shutdownTimeout)
+			defer cancel()
+		}
+		return tp.Shutdown(ctx)
+	}
+	return shutdown, nil
+}
+
+// Tracer returns a tracer for the inference scheduler.
+// The tracer is identified by the instrumentation library name, which is
+// distinct from the service name set during InitTracing().
+func Tracer() trace.Tracer {
+	return otel.Tracer(instrumentationName)
+}