Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ go.work.sum
.vscode

vendor

# Build output
/build
15 changes: 10 additions & 5 deletions pkg/plugins/profile/pd_profile_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ const (

defaultDecodeProfile = "decode"
defaultPrefillProfile = "prefill"
defaultPrefixPluginName = prefix.PrefixCachePluginType
defaultPrefixPluginType = prefix.PrefixCachePluginType
)

type pdProfileHandlerParameters struct {
Threshold int `json:"threshold"`
DecodeProfile string `json:"decodeProfile"`
PrefillProfile string `json:"prefillProfile"`
PrefixPluginType string `json:"prefixPluginType"`
PrefixPluginName string `json:"prefixPluginName"`
HashBlockSize int `json:"hashBlockSize"`
PrimaryPort int `json:"primaryPort"`
Expand All @@ -48,7 +49,7 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi
Threshold: 0,
DecodeProfile: defaultDecodeProfile,
PrefillProfile: defaultPrefillProfile,
PrefixPluginName: defaultPrefixPluginName,
PrefixPluginType: defaultPrefixPluginType,
HashBlockSize: prefix.DefaultBlockSize,
PrimaryPort: 0,
}
Expand All @@ -58,6 +59,10 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi
}
}

if parameters.PrefixPluginName == "" {
parameters.PrefixPluginName = parameters.PrefixPluginType
}

if parameters.Threshold < 0 {
return nil, fmt.Errorf("invalid threshold: must be >= 0, got %d", parameters.Threshold)
}
Expand All @@ -72,15 +77,15 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi
}
}

return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginName,
return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName,
parameters.Threshold, parameters.HashBlockSize, parameters.PrimaryPort).WithName(name), nil
}

// NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer.
func NewPdProfileHandler(prefillProfile string, decodeProfile string, prefixPluginName string, pdThreshold int, hashBlockSize int, primaryPort int) *PdProfileHandler {
func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, pdThreshold, hashBlockSize, primaryPort int) *PdProfileHandler {
result := &PdProfileHandler{
typedName: plugins.TypedName{Type: PdProfileHandlerType},
prefixPluginTypedName: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefixPluginName},
prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName},
decodeProfile: decodeProfile,
prefillProfile: prefillProfile,
pdThreshold: pdThreshold,
Expand Down
8 changes: 8 additions & 0 deletions pkg/plugins/profile/pd_profile_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
name string
pdThreshold int
hashBlockSize int
prefixPluginType string
prefixPluginName string
setupPrefixState func(*types.CycleState)
profileResults map[string]*types.ProfileRunResult
Expand All @@ -241,6 +242,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
name: "decode not executed yet → run decode",
pdThreshold: 100,
hashBlockSize: 16,
prefixPluginType: prefix.PrefixCachePluginType,
prefixPluginName: prefix.PrefixCachePluginType,
profileResults: map[string]*types.ProfileRunResult{},
expectedProfiles: []string{"decode"},
Expand All @@ -249,6 +251,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
name: "decode failed (nil result) → run nothing",
pdThreshold: 100,
hashBlockSize: 16,
prefixPluginType: prefix.PrefixCachePluginType,
prefixPluginName: prefix.PrefixCachePluginType,
profileResults: map[string]*types.ProfileRunResult{
"decode": nil,
Expand All @@ -259,6 +262,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
name: "all profiles already executed → run nothing",
pdThreshold: 100,
hashBlockSize: 16,
prefixPluginType: prefix.PrefixCachePluginType,
prefixPluginName: prefix.PrefixCachePluginType,
profileResults: map[string]*types.ProfileRunResult{
"decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
Expand All @@ -270,6 +274,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
name: "pd threshold NOT triggered → run prefill",
pdThreshold: 5,
hashBlockSize: 16,
prefixPluginType: prefix.PrefixCachePluginType,
prefixPluginName: prefix.PrefixCachePluginType,
setupPrefixState: func(cs *types.CycleState) {
state := &prefix.SchedulingContextState{
Expand All @@ -289,6 +294,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
name: "pd threshold triggered (short non-cached suffix) → skip prefill",
pdThreshold: 100,
hashBlockSize: 16,
prefixPluginType: prefix.PrefixCachePluginType,
prefixPluginName: prefix.PrefixCachePluginType,
setupPrefixState: func(cs *types.CycleState) {
state := &prefix.SchedulingContextState{
Expand All @@ -311,6 +317,7 @@ func TestPdProfileHandler_Pick(t *testing.T) {
handler := NewPdProfileHandler(
"prefill",
"decode",
tt.prefixPluginType,
tt.prefixPluginName,
tt.pdThreshold,
tt.hashBlockSize,
Expand Down Expand Up @@ -402,6 +409,7 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) {
"prefill",
"decode",
prefix.PrefixCachePluginType,
prefix.PrefixCachePluginType,
0,
prefix.DefaultBlockSize,
tt.primaryPort,
Expand Down
25 changes: 16 additions & 9 deletions pkg/plugins/scorer/no_hit_lru.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ var _ requestcontrol.PreRequest = &NoHitLRU{}

// NoHitLRUParameters defines the parameters for the NoHitLRU scorer.
type NoHitLRUParameters struct {
// PrefixPluginType defines the type of the prefix cache plugin to read state from.
// Defaults to "prefix-cache-scorer".
PrefixPluginType string `json:"prefixPluginType"`
// PrefixPluginName defines the name of the prefix cache plugin to read state from.
// Defaults to "prefix-cache-scorer".
PrefixPluginName string `json:"prefixPluginName"`
Expand Down Expand Up @@ -69,10 +72,14 @@ func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugins.

// NewNoHitLRU creates a new NoHitLRU scorer
func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU {
prefixPluginType := prefix.PrefixCachePluginType
prefixPluginName := prefix.PrefixCachePluginType
lruSize := defaultLRUSize

if params != nil {
if params.PrefixPluginType != "" {
prefixPluginType = params.PrefixPluginType
}
if params.PrefixPluginName != "" {
prefixPluginName = params.PrefixPluginName
}
Expand All @@ -88,21 +95,21 @@ func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU {
}

return &NoHitLRU{
typedName: plugins.TypedName{Type: NoHitLRUType},
lruCache: lruCache,
prefixPluginName: prefixPluginName,
pluginState: plugins.NewPluginState(ctx),
typedName: plugins.TypedName{Type: NoHitLRUType},
lruCache: lruCache,
prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName},
pluginState: plugins.NewPluginState(ctx),
}
}

// NoHitLRU scorer that favors pods that were least recently used for cold requests.
// This can help evenly distribute cache growth, since cold requests result in more
// new KV blocks.
type NoHitLRU struct {
typedName plugins.TypedName
lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order)
prefixPluginName string
pluginState *plugins.PluginState
typedName plugins.TypedName
lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order)
prefixPluginTypedName plugins.TypedName
pluginState *plugins.PluginState
}

// TypedName returns the typed name of the plugin.
Expand All @@ -123,7 +130,7 @@ func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleSta

// Read prefix cache state to determine if this is a cold request
// This is treated as an optimization - if the state isn't available, we assume cold request
prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginName))
prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginTypedName.String()))

if err != nil {
logger.Info("No prefix cache state found, treating as cold request for LRU optimization", "error", err)
Expand Down
18 changes: 12 additions & 6 deletions pkg/plugins/scorer/no_hit_lru_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ func TestNoHitLRUScorer(t *testing.T) {
// Create cycle state and set prefix state
cycleState := &types.CycleState{}
if test.prefixState != nil {
cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), test.prefixState)
cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), test.prefixState)
}

got := test.scorer.Score(context.Background(), cycleState, test.req, test.input)
Expand Down Expand Up @@ -234,7 +235,8 @@ func TestNoHitLRUBasicFunctionality(t *testing.T) {
PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request
}
cycleState := &types.CycleState{}
cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), coldPrefixState)
cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), coldPrefixState)

scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods)

Expand Down Expand Up @@ -295,7 +297,8 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) {
primaryProfile := "primary-profile"
toPrefixState := func(entries map[prefix.ServerID]int) *types.CycleState {
cycle := &types.CycleState{}
cycle.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{PrefixCacheServers: entries})
cycle.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{PrefixCacheServers: entries})
return cycle
}

Expand Down Expand Up @@ -406,7 +409,8 @@ func TestNoHitLRUEdgeCases(t *testing.T) {
t.Run("empty pods list", func(t *testing.T) {
emptyPods := []types.Pod{}
cycleState := &types.CycleState{}
cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{
cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // cold request
})

Expand All @@ -419,7 +423,8 @@ func TestNoHitLRUEdgeCases(t *testing.T) {

t.Run("nil pods list", func(t *testing.T) {
cycleState := &types.CycleState{}
cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{
cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // cold request
})

Expand All @@ -436,7 +441,8 @@ func TestNoHitLRUEdgeCases(t *testing.T) {
t.Run("single pod returns 1.0", func(t *testing.T) {
pods := []types.Pod{podA}
cycleState := &types.CycleState{}
cycleState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{
cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // cold request
})

Expand Down
16 changes: 15 additions & 1 deletion pkg/plugins/scorer/precise_prefix_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)
Expand Down Expand Up @@ -125,7 +126,7 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor

// Score scores the provided pod based on the KVCache index state.
// The returned scores are normalized to a range of 0-1.
func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
logger := log.FromContext(ctx).WithName(s.typedName.String())
debugLogger := logger.V(logutil.DEBUG)

Expand All @@ -150,6 +151,19 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, _ *types.CycleStat
return metricsPod.Address, true
}

state := &prefix.SchedulingContextState{
PrefixHashes: []prefix.BlockHash{},
PrefixCacheServers: map[prefix.ServerID]int{},
}
for _, pod := range pods {
key, ok := podToKey(pod)
if !ok {
continue
}
state.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] = int(scores[key])
}
cycleState.Write(plugins.StateKey(s.typedName.String()), state)

return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/plugins/scorer/precise_prefix_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
}
}

got := prefixCacheScorer.Score(ctx, nil, tt.request, tt.pods)
got := prefixCacheScorer.Score(ctx, types.NewCycleState(), tt.request, tt.pods)

gotByAddress := make(map[string]float64)
for pod, score := range got {
Expand Down
2 changes: 1 addition & 1 deletion pkg/scheduling/pd/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func TestPDSchedule(t *testing.T) {
err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0))
assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error")

profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Name, 10, 5, 0)
profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 5, 0)

schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]*framework.SchedulerProfile{
prefill: prefillSchedulerProfile,
Expand Down