diff --git a/cmd/epp/runner/runner.go b/cmd/epp/runner/runner.go index 699585ede..ad0760886 100644 --- a/cmd/epp/runner/runner.go +++ b/cmd/epp/runner/runner.go @@ -473,9 +473,22 @@ func (r *Runner) parseConfigurationPhaseOne(ctx context.Context) (*configapi.End return rawConfig, nil } +// Return a function that can be used in the EPP Handle to list pod names. +func makePodListFunc(ds datastore.Datastore) func() []types.NamespacedName { + return func() []types.NamespacedName { + pods := ds.PodList(func(_ backendmetrics.PodMetrics) bool { return true }) + names := make([]types.NamespacedName, 0, len(pods)) + + for _, p := range pods { + names = append(names, p.GetPod().NamespacedName) + } + return names + } +} + func (r *Runner) parseConfigurationPhaseTwo(ctx context.Context, rawConfig *configapi.EndpointPickerConfig, ds datastore.Datastore) (*config.Config, error) { logger := log.FromContext(ctx) - handle := plugins.NewEppHandle(ctx, ds.PodList) + handle := plugins.NewEppHandle(ctx, makePodListFunc(ds)) cfg, err := loader.LoadConfigPhaseTwo(rawConfig, handle, logger) if err != nil { @@ -604,8 +617,7 @@ func setupDatalayer(logger logr.Logger) (datalayer.EndpointFactory, error) { // create and register a metrics data source and extractor. source := dlmetrics.NewDataSource(*modelServerMetricsScheme, *modelServerMetricsPath, - *modelServerMetricsHttpsInsecureSkipVerify, - nil) + *modelServerMetricsHttpsInsecureSkipVerify) extractor, err := dlmetrics.NewExtractor(*totalQueuedRequestsMetric, *totalRunningRequestsMetric, *kvCacheUsagePercentageMetric, @@ -624,7 +636,7 @@ func setupDatalayer(logger logr.Logger) (datalayer.EndpointFactory, error) { // TODO: this could be moved to the configuration loading functions once ported over. 
sources := datalayer.GetSources() for _, src := range sources { - logger.Info("data layer configuration", "source", src.Name(), "extractors", src.Extractors()) + logger.Info("data layer configuration", "source", src.TypedName().String(), "extractors", src.Extractors()) } factory := datalayer.NewEndpointFactory(sources, *refreshMetricsInterval) return factory, nil diff --git a/pkg/epp/datalayer/collector_test.go b/pkg/epp/datalayer/collector_test.go index f0655a7c7..22408fcdf 100644 --- a/pkg/epp/datalayer/collector_test.go +++ b/pkg/epp/datalayer/collector_test.go @@ -27,6 +27,7 @@ import ( "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/mocks" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) // --- Test Stubs --- @@ -35,7 +36,16 @@ type DummySource struct { callCount int64 } -func (d *DummySource) Name() string { return "test-dummy-data-source" } +const ( + dummySource = "test-dummy-data-source" +) + +func (d *DummySource) TypedName() plugins.TypedName { + return plugins.TypedName{ + Type: dummySource, + Name: dummySource, + } +} func (d *DummySource) Extractors() []string { return []string{} } func (d *DummySource) AddExtractor(_ Extractor) error { return nil } func (d *DummySource) Collect(ctx context.Context, ep Endpoint) error { diff --git a/pkg/epp/datalayer/config.go b/pkg/epp/datalayer/config.go new file mode 100644 index 000000000..ed4e91a03 --- /dev/null +++ b/pkg/epp/datalayer/config.go @@ -0,0 +1,29 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datalayer + +// Config defines the configuration of EPP data layer, as the set of DataSources and +// Extractors defined on them. +type Config struct { + Sources []DataSourceConfig // the data sources configured in the data layer +} + +// DataSourceConfig defines the configuration of a specific DataSource +type DataSourceConfig struct { + Plugin DataSource // the data source plugin instance + Extractors []Extractor // extractors defined for the data source +} diff --git a/pkg/epp/datalayer/datasource.go b/pkg/epp/datalayer/datasource.go index 40fd365b9..5e0f35878 100644 --- a/pkg/epp/datalayer/datasource.go +++ b/pkg/epp/datalayer/datasource.go @@ -22,12 +22,13 @@ import ( "fmt" "reflect" "sync" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) // DataSource provides raw data to registered Extractors. type DataSource interface { - // Name of this data source. - Name() string + plugins.Plugin // Extractors returns a list of registered Extractor names. Extractors() []string // AddExtractor adds an extractor to the data source. Multiple @@ -45,7 +46,7 @@ type DataSource interface { // Extractor transforms raw data into structured attributes. type Extractor interface { - Name() string + plugins.Plugin // ExpectedType defines the type expected by the extractor. 
ExpectedInputType() reflect.Type // Extract transforms the raw data source output into a concrete structured @@ -65,22 +66,12 @@ func (dsr *DataSourceRegistry) Register(src DataSource) error { if src == nil { return errors.New("unable to register a nil data source") } - if _, loaded := dsr.sources.LoadOrStore(src.Name(), src); loaded { - return fmt.Errorf("unable to register duplicate data source: %s", src.Name()) + if _, loaded := dsr.sources.LoadOrStore(src.TypedName().Name, src); loaded { + return fmt.Errorf("unable to register duplicate data source: %s", src.TypedName().String()) } return nil } -// GetNamedSource fetches a source by name. -func (dsr *DataSourceRegistry) GetNamedSource(name string) (DataSource, bool) { - if val, ok := dsr.sources.Load(name); ok { - if ds, ok := val.(DataSource); ok { - return ds, true - } - } - return nil, false -} - // GetSources returns all registered sources. func (dsr *DataSourceRegistry) GetSources() []DataSource { var result []DataSource @@ -100,21 +91,6 @@ func RegisterSource(src DataSource) error { return defaultDataSources.Register(src) } -// GetNamedSource returns a typed data source from the default registry. -func GetNamedSource[T DataSource](name string) (T, bool) { - v, ok := defaultDataSources.GetNamedSource(name) - if !ok { - var zero T - return zero, false - } - src, ok := v.(T) - if !ok { - var zero T - return zero, false - } - return src, true -} - // GetSources returns the list of data sources registered in the default registry. 
func GetSources() []DataSource { return defaultDataSources.GetSources() diff --git a/pkg/epp/datalayer/datasource_test.go b/pkg/epp/datalayer/datasource_test.go index e5cb41dfd..7ac262a5c 100644 --- a/pkg/epp/datalayer/datasource_test.go +++ b/pkg/epp/datalayer/datasource_test.go @@ -22,20 +22,26 @@ import ( "testing" "github.com/stretchr/testify/assert" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" +) + +const ( + testType = "test-ds-type" ) type mockDataSource struct { - name string + tn plugins.TypedName } -func (m *mockDataSource) Name() string { return m.name } +func (m *mockDataSource) TypedName() plugins.TypedName { return m.tn } func (m *mockDataSource) Extractors() []string { return []string{} } func (m *mockDataSource) AddExtractor(_ Extractor) error { return nil } func (m *mockDataSource) Collect(_ context.Context, _ Endpoint) error { return nil } func TestRegisterAndGetSource(t *testing.T) { reg := DataSourceRegistry{} - ds := &mockDataSource{name: "test"} + ds := &mockDataSource{tn: plugins.TypedName{Type: testType, Name: testType}} err := reg.Register(ds) assert.NoError(t, err, "expected no error on first registration") @@ -47,35 +53,25 @@ func TestRegisterAndGetSource(t *testing.T) { err = reg.Register(nil) assert.Error(t, err, "expected error on nil") - // Get by name - got, found := reg.GetNamedSource("test") - assert.True(t, found, "expected to find registered data source") - assert.Equal(t, "test", got.Name()) - // Get all sources all := reg.GetSources() assert.Len(t, all, 1) - assert.Equal(t, "test", all[0].Name()) + assert.Equal(t, testType, all[0].TypedName().Type) // Default registry err = RegisterSource(ds) assert.NoError(t, err, "expected no error on registration") - // Get by name - got, found = GetNamedSource[*mockDataSource]("test") - assert.True(t, found, "expected to find registered data source") - assert.Equal(t, "test", got.Name()) - // Get all sources all = GetSources() assert.Len(t, all, 1) - assert.Equal(t, 
"test", all[0].Name()) + assert.Equal(t, testType, all[0].TypedName().Type) } -func TestGetNamedSourceWhenNotFound(t *testing.T) { +func TestGetSourceWhenNoneAreRegistered(t *testing.T) { reg := DataSourceRegistry{} - _, found := reg.GetNamedSource("missing") - assert.False(t, found, "expected source to be missing") + found := reg.GetSources() + assert.Empty(t, found, "expected no sources to be returned") } func TestValidateExtractorType(t *testing.T) { diff --git a/pkg/epp/datalayer/metrics/datasource.go b/pkg/epp/datalayer/metrics/datasource.go index 81723d4e0..d5940ac65 100644 --- a/pkg/epp/datalayer/metrics/datasource.go +++ b/pkg/epp/datalayer/metrics/datasource.go @@ -25,15 +25,17 @@ import ( "sync" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) const ( - DataSourceName = "metrics-data-source" + DataSourceType = "metrics-data-source" ) // DataSource is a Model Server Protocol (MSP) compliant metrics data source, // returning Prometheus formatted metrics for an endpoint. type DataSource struct { + tn plugins.TypedName metricsScheme string // scheme to use in metrics URL metricsPath string // path to use in metrics URL @@ -42,10 +44,9 @@ type DataSource struct { } // NewDataSource returns a new MSP compliant metrics data source, configured with -// the provided client factory. If ClientFactory is nil, a default factory is used. -// The Scheme, port and path are command line options. It should be noted that -// a port value of zero is set if the command line is unspecified. -func NewDataSource(metricsScheme string, metricsPath string, skipCertVerification bool, cl Client) *DataSource { +// the provided client configuration. +// The Scheme, path and certificate validation setting are command line options. 
+func NewDataSource(metricsScheme string, metricsPath string, skipCertVerification bool) *DataSource { if metricsScheme == "https" { httpsTransport := baseTransport.Clone() httpsTransport.TLSClientConfig = &tls.Config{ @@ -54,33 +55,33 @@ func NewDataSource(metricsScheme string, metricsPath string, skipCertVerificatio defaultClient.Transport = httpsTransport } - if cl == nil { - cl = defaultClient - } - dataSrc := &DataSource{ + tn: plugins.TypedName{ + Type: DataSourceType, + Name: DataSourceType, + }, metricsScheme: metricsScheme, metricsPath: metricsPath, - client: cl, + client: defaultClient, } return dataSrc } -// Name returns the metrics data source name. -func (dataSrc *DataSource) Name() string { - return DataSourceName +// TypedName returns the metrics data source type and name. +func (dataSrc *DataSource) TypedName() plugins.TypedName { + return dataSrc.tn } // Extractors returns a list of registered Extractor names. func (dataSrc *DataSource) Extractors() []string { - names := []string{} + extractors := []string{} dataSrc.extractors.Range(func(_, val any) bool { if ex, ok := val.(datalayer.Extractor); ok { - names = append(names, ex.Name()) + extractors = append(extractors, ex.TypedName().String()) } return true // continue iteration }) - return names + return extractors } // AddExtractor adds an extractor to the data source, validating it can process @@ -89,8 +90,8 @@ func (dataSrc *DataSource) AddExtractor(extractor datalayer.Extractor) error { if err := datalayer.ValidateExtractorType(PrometheusMetricType, extractor.ExpectedInputType()); err != nil { return err } - if _, loaded := dataSrc.extractors.LoadOrStore(extractor.Name(), extractor); loaded { - return fmt.Errorf("attempt to add extractor with duplicate name %s to %s", extractor.Name(), dataSrc.Name()) + if _, loaded := dataSrc.extractors.LoadOrStore(extractor.TypedName().Name, extractor); loaded { + return fmt.Errorf("attempt to add duplicate extractor %s to %s", extractor.TypedName(), 
dataSrc.TypedName()) } return nil } diff --git a/pkg/epp/datalayer/metrics/datasource_test.go b/pkg/epp/datalayer/metrics/datasource_test.go index 016622a40..4d4db2a01 100644 --- a/pkg/epp/datalayer/metrics/datasource_test.go +++ b/pkg/epp/datalayer/metrics/datasource_test.go @@ -28,12 +28,12 @@ import ( ) func TestDatasource(t *testing.T) { - source := NewDataSource("https", "/metrics", true, nil) + source := NewDataSource("https", "/metrics", true) extractor, err := NewExtractor(defaultTotalQueuedRequestsMetric, "", "", "", "") assert.Nil(t, err, "failed to create extractor") - name := source.Name() - assert.Equal(t, DataSourceName, name) + dsType := source.TypedName().Type + assert.Equal(t, DataSourceType, dsType) err = source.AddExtractor(extractor) assert.Nil(t, err, "failed to add extractor") @@ -43,7 +43,7 @@ func TestDatasource(t *testing.T) { extractors := source.Extractors() assert.Len(t, extractors, 1) - assert.Equal(t, extractor.Name(), extractors[0]) + assert.Equal(t, extractor.TypedName().String(), extractors[0]) err = datalayer.RegisterSource(source) assert.Nil(t, err, "failed to register") diff --git a/pkg/epp/datalayer/metrics/extractor.go b/pkg/epp/datalayer/metrics/extractor.go index f08ccf95e..9f450ee5b 100644 --- a/pkg/epp/datalayer/metrics/extractor.go +++ b/pkg/epp/datalayer/metrics/extractor.go @@ -30,11 +30,12 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) const ( - extractorName = "model-server-protocol-metrics" + extractorType = "model-server-protocol-metrics" // LoRA metrics based on MSP LoraInfoRunningAdaptersMetricName = "running_lora_adapters" @@ -48,6 +49,7 @@ const ( // Extractor implements the metrics extraction based on the model // server protocol standard. 
type Extractor struct { + tn plugins.TypedName mapping *Mapping } @@ -72,13 +74,17 @@ func NewExtractor(queueSpec, runningSpec, kvusageSpec, loraSpec, cacheInfoSpec s return nil, fmt.Errorf("failed to create extractor metrics Mapping - %w", err) } return &Extractor{ + tn: plugins.TypedName{ + Type: extractorType, + Name: extractorType, + }, mapping: mapping, }, nil } -// Name returns the name of the metrics.Extractor. -func (ext *Extractor) Name() string { - return extractorName +// TypedName returns the type and name of the metrics.Extractor. +func (ext *Extractor) TypedName() plugins.TypedName { + return ext.tn } // ExpectedType defines the type expected by the metrics.Extractor - a diff --git a/pkg/epp/datalayer/metrics/extractor_test.go b/pkg/epp/datalayer/metrics/extractor_test.go index bb408f6db..45638d9fb 100644 --- a/pkg/epp/datalayer/metrics/extractor_test.go +++ b/pkg/epp/datalayer/metrics/extractor_test.go @@ -50,8 +50,8 @@ func TestExtractorExtract(t *testing.T) { t.Fatalf("failed to create extractor: %v", err) } - if name := extractor.Name(); name == "" { - t.Error("empty extractor name") + if exType := extractor.TypedName().Type; exType == "" { + t.Error("empty extractor type") } if inputType := extractor.ExpectedInputType(); inputType != PrometheusMetricType { diff --git a/pkg/epp/plugins/handle.go b/pkg/epp/plugins/handle.go index c074e9076..c6e9c0dbe 100644 --- a/pkg/epp/plugins/handle.go +++ b/pkg/epp/plugins/handle.go @@ -20,7 +20,7 @@ import ( "context" "fmt" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "k8s.io/apimachinery/pkg/types" ) // Handle provides plugins a set of standard data and tools to work with @@ -30,8 +30,8 @@ type Handle interface { HandlePlugins - // PodList lists pods matching the given predicate. - PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + // PodList lists pods. 
+ PodList() []types.NamespacedName } // HandlePlugins defines a set of APIs to work with instantiated plugins @@ -50,7 +50,7 @@ type HandlePlugins interface { } -// PodListFunc is a function type that filters and returns a list of pod metrics -type PodListFunc func(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics +// PodListFunc is a function type that returns a list of pod namespaced names +type PodListFunc func() []types.NamespacedName // eppHandle is an implementation of the interface plugins.Handle type eppHandle struct { @@ -93,9 +93,9 @@ func (h *eppHandlePlugins) GetAllPluginsWithNames() map[string]Plugin { return h.plugins } -// PodList lists pods matching the given predicate. -func (h *eppHandle) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics { - return h.podList(predicate) +// PodList lists pods. +func (h *eppHandle) PodList() []types.NamespacedName { + return h.podList() } func NewEppHandle(ctx context.Context, podList PodListFunc) Handle { diff --git a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go index 2947785d2..2a1a3a8b2 100644 --- a/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go +++ b/pkg/epp/scheduling/framework/plugins/multi/prefix/plugin.go @@ -28,7 +28,6 @@ import ( k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" @@ -307,10 +306,10 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle) case <-ctx.Done(): return case <-ticker.C: - activePodMetrics := handle.PodList(func(_ backendmetrics.PodMetrics) bool { return true }) - activePods := make(map[ServerID]struct{}, len(activePodMetrics)) - for _, pm := range activePodMetrics { -
activePods[ServerID(pm.GetPod().NamespacedName)] = struct{}{} + podNames := handle.PodList() + activePods := make(map[ServerID]struct{}, len(podNames)) + for _, nsn := range podNames { + activePods[ServerID(nsn)] = struct{}{} } for _, pod := range m.indexer.Pods() { diff --git a/test/utils/handle.go b/test/utils/handle.go index 273539f81..15dfe10a0 100644 --- a/test/utils/handle.go +++ b/test/utils/handle.go @@ -19,7 +19,8 @@ package utils import ( "context" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" ) @@ -34,8 +35,8 @@ func (h *testHandle) Context() context.Context { return h.ctx } -func (h *testHandle) PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics { - return []backendmetrics.PodMetrics{} +func (h *testHandle) PodList() []types.NamespacedName { + return []types.NamespacedName{} } type testHandlePlugins struct {