diff --git a/docs/architecture.md b/docs/architecture.md index c51a67293b..a1039cfabf 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -339,6 +339,31 @@ plugins: huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable ``` +Example configuration for automatic pod discovery in active-active multi-replica scheduler deployments: +```yaml + - type: precise-prefix-cache-scorer + parameters: + tokenProcessorConfig: + blockSize: 64 + hashSeed: "42" + indexerConfig: + tokenizersPoolConfig: + modelName: "Qwen/Qwen3-32B" + hf: + tokenizersCacheDir: "/tmp/tokenizers" + kvEventsConfig: + topicFilter: "kv@" + concurrency: 4 + discoverPods: true # enables automatic pod discovery for active-active HA + podDiscoveryConfig: + socketPort: 5556 +``` + +Where the vLLM engines are configured to emit KV-Events on port `5556` as follows: +```yaml + --kv-events-config "{\"enable_kv_cache_events\":true,\"publisher\":\"zmq\",\"endpoint\":\"tcp://*:5556\",\"topic\":\"kv@${POD_IP}@Qwen/Qwen3-32B\"}" +``` + Example configuration with all parameters set: ```yaml @@ -349,9 +374,11 @@ plugins: blockSize: 16 hashSeed: "12345" kvEventsConfig: - zmqEndpoint: tcp://*:5557 - topicFilter: kv@ - concurrency: 8 + topicFilter: "kv@" + concurrency: 4 + discoverPods: true # enables automatic pod discovery for active-active HA + podDiscoveryConfig: + socketPort: 5556 indexerConfig: prefixStoreConfig: cacheSize: 500000 diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go index 84947c9921..2b0a0d5597 100644 --- a/pkg/plugins/scorer/active_request.go +++ b/pkg/plugins/scorer/active_request.go @@ -250,7 +250,7 @@ func (s *ActiveRequest) decrementPodCount(podName string) { } } -func cleanCachePeriodically(ctx context.Context, cache *ttlcache.Cache[string, *requestEntry], requestTimeout time.Duration) { +func cleanCachePeriodically[K comparable, V any](ctx context.Context, cache *ttlcache.Cache[K, V], requestTimeout 
time.Duration) { ticker := time.NewTicker(requestTimeout) defer ticker.Stop() diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index 000d9845d1..39a2ac8445 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "os" + "time" + "github.com/jellydator/ttlcache/v3" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" @@ -46,7 +48,6 @@ var _ framework.Scorer = &PrecisePrefixCacheScorer{} // a new instance of the PrefixCacheTrackingPlugin. func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { - indexerConfig, err := kvcache.NewDefaultConfig() if err != nil { return nil, fmt.Errorf("failed to initialize indexer config: %w", err) @@ -113,9 +114,39 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex(), tokenProcessor) pool.Start(ctx) + subscribersManager := kvevents.NewSubscriberManager(pool) + var subscribersCache *ttlcache.Cache[string, struct{}] + + // initialize the subscribers cache only if pod discovery is enabled + if config.KVEventsConfig.DiscoverPods { + // initialize the subscribers TTL cache + subscriptionTimeout := 10 * time.Minute + subscribersCache = ttlcache.New[string, struct{}]( + ttlcache.WithTTL[string, struct{}](subscriptionTimeout), + ) + subscribersCache.OnEviction(func(ctx context.Context, reason ttlcache.EvictionReason, + item *ttlcache.Item[string, struct{}], + ) { + if reason == ttlcache.EvictionReasonExpired { + subscribersManager.RemoveSubscriber(ctx, item.Key()) + } + }) + go cleanCachePeriodically(ctx, subscribersCache, subscriptionTimeout) + } + if config.KVEventsConfig.ZMQEndpoint != "" { + // setup local subscriber to support 
global socket mode + if err := subscribersManager.EnsureSubscriber(ctx, "local-subscriber", + config.KVEventsConfig.ZMQEndpoint, config.KVEventsConfig.TopicFilter, false); err != nil { + return nil, fmt.Errorf("failed to create local subscriber for global socket mode: %w", err) + } + } + return &PrecisePrefixCacheScorer{ - typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType}, - kvCacheIndexer: kvCacheIndexer, + typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType}, + kvCacheIndexer: kvCacheIndexer, + subscribersCache: subscribersCache, + subscribersManager: subscribersManager, + kvEventsConfig: config.KVEventsConfig, }, nil } @@ -127,6 +158,15 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr type PrecisePrefixCacheScorer struct { typedName plugins.TypedName kvCacheIndexer *kvcache.Indexer + + // until the IGW data-layer is ready to provide endpoint events, + // we maintain a TTL cache of known pods that are discovered through + // the scoring process. If a pod is not in the received endpoints list + // during scoring for a certain period, we consider it gone and + // stop its KV events subscription. + subscribersCache *ttlcache.Cache[string, struct{}] + subscribersManager *kvevents.SubscriberManager + kvEventsConfig *kvevents.Config } // TypedName returns the typed name of the plugin. @@ -146,6 +186,26 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types. 
logger := log.FromContext(ctx).WithName(s.typedName.String()) debugLogger := logger.V(logutil.DEBUG) + if s.kvEventsConfig.DiscoverPods { + // update subscribers here temporarily + for _, pod := range pods { + podObj := pod.GetPod() + if podObj == nil { + continue + } + podKey := podObj.NamespacedName.String() + s.subscribersCache.Set(podKey, struct{}{}, 0) // use default TTL + + if err := s.subscribersManager.EnsureSubscriber(context.Background(), podKey, // don't use request ctx + fmt.Sprintf("tcp://%s:%d", podObj.Address, s.kvEventsConfig.PodDiscoveryConfig.SocketPort), + s.kvEventsConfig.TopicFilter, true); err != nil { + logger.Error(err, "Failed to ensure KV-events subscriber for pod", "pod", podKey, + "endpoint", podObj.Address) + continue + } + } + } + if request == nil { debugLogger.Info("Request is nil, skipping scoring") return nil