diff --git a/examples/helper/events.go b/examples/helper/events.go index 6e39adb02..7e889bc7b 100644 --- a/examples/helper/events.go +++ b/examples/helper/events.go @@ -32,12 +32,15 @@ const topic = "kv@vllm-pod1@" + testdata.ModelName func SimulateProduceEvent(ctx context.Context, publisher *Publisher) error { logger := log.FromContext(ctx) logger.Info("@@@ Simulating vLLM engine publishing BlockStored events...") + medium := "GPU" blockStoredEvent := kvevents.BlockStored{ BlockHashes: utils.SliceMap(testdata.PromptHashes, func(h uint64) any { return h }), ParentBlockHash: nil, TokenIds: []uint32{1, 2, 3}, BlockSize: 256, LoraID: nil, + Medium: &medium, + LoraName: nil, } //nolint // won't fail diff --git a/pkg/kvcache/kvevents/events.go b/pkg/kvcache/kvevents/events.go index d56a975bf..94ad27953 100644 --- a/pkg/kvcache/kvevents/events.go +++ b/pkg/kvcache/kvevents/events.go @@ -15,6 +15,8 @@ package kvevents import ( + "fmt" + "github.com/vmihailenco/msgpack/v5" ) @@ -53,6 +55,7 @@ type BlockStored struct { BlockSize int LoraID *int `msgpack:",omitempty"` Medium *string `msgpack:",omitempty"` + LoraName *string `msgpack:",omitempty"` } // ToTaggedUnion converts the BlockStored event to a tagged union format. @@ -67,6 +70,7 @@ func (bs BlockStored) ToTaggedUnion() []any { bs.BlockSize, bs.LoraID, bs.Medium, + bs.LoraName, } } @@ -102,3 +106,46 @@ func (ac AllBlocksCleared) ToTaggedUnion() []any { } func (AllBlocksCleared) isEvent() {} + +// UnmarshalKVEvent unmarshals a raw msgpack event into the event interface. +func UnmarshalKVEvent(rawEvent msgpack.RawMessage) (event, error) { + var taggedUnion []msgpack.RawMessage + if err := msgpack.Unmarshal(rawEvent, &taggedUnion); err != nil { + return nil, fmt.Errorf("failed to unmarshal tagged union: %w", err) + } + + if len(taggedUnion) < 1 { + return nil, fmt.Errorf("malformed tagged union: no tag") + } + + var tag string + if err := msgpack.Unmarshal(taggedUnion[0], &tag); err != nil { + return nil, fmt.Errorf("failed to unmarshal tag: %w", err) + } + + payloadBytes, err := msgpack.Marshal(taggedUnion[1:]) + if err != nil { + return nil, fmt.Errorf("failed to re-marshal payload parts: %w", err) + } + + var unmarshalErr error + switch tag { + case BlockStoredEventTag: + var bs BlockStored + unmarshalErr = msgpack.Unmarshal(payloadBytes, &bs) + return bs, unmarshalErr + + case BlockRemovedEventTag: + var br BlockRemoved + unmarshalErr = msgpack.Unmarshal(payloadBytes, &br) + return br, unmarshalErr + + case AllBlocksClearedEventTag: + var ac AllBlocksCleared + unmarshalErr = msgpack.Unmarshal(payloadBytes, &ac) + return ac, unmarshalErr + + default: + return nil, fmt.Errorf("unknown event tag: %s", tag) + } +} diff --git a/pkg/kvcache/kvevents/pool.go b/pkg/kvcache/kvevents/pool.go index bd0608ce5..1690e5094 100644 --- a/pkg/kvcache/kvevents/pool.go +++ b/pkg/kvcache/kvevents/pool.go @@ -188,51 +188,9 @@ func (p *Pool) processEvent(ctx context.Context, msg *Message) { events := make([]event, 0, len(eventBatch.Events)) for _, rawEvent := range eventBatch.Events { - var taggedUnion []msgpack.RawMessage - if err := msgpack.Unmarshal(rawEvent, &taggedUnion); err != nil { - debugLogger.Error(err, "Failed to unmarshal tagged union, skipping event") - continue - } - - // Handle array_like tagged union: re-marshall tail parts into a payload array - if len(taggedUnion) < 1 { - debugLogger.Error(nil, "Malformed tagged union, no tag element", "parts", len(taggedUnion)) - continue - } - payloadBytes, err := msgpack.Marshal(taggedUnion[1:]) + event, err := UnmarshalKVEvent(rawEvent) if err != nil { - debugLogger.Error(err, "Failed to re-marshal payload parts, skipping event") - continue - } - - var tag string - if err := msgpack.Unmarshal(taggedUnion[0], &tag); err != nil { - debugLogger.Error(err, "Failed to unmarshal tag from tagged union, skipping event") - continue - } - - var event event - var unmarshalErr error - switch tag { - case "BlockStored": - var bs BlockStored - unmarshalErr = msgpack.Unmarshal(payloadBytes, &bs) - event = bs - case "BlockRemoved": - var br BlockRemoved - unmarshalErr = msgpack.Unmarshal(payloadBytes, &br) - event = br - case "AllBlocksCleared": - var ac AllBlocksCleared - unmarshalErr = msgpack.Unmarshal(payloadBytes, &ac) - event = ac - default: - debugLogger.Info("Unknown event tag", "tag", tag) - continue - } - - if unmarshalErr != nil { - debugLogger.Error(unmarshalErr, "Failed to unmarshal event value", "tag", tag) + debugLogger.Error(err, "Failed to unmarshal event, skipping") continue } events = append(events, event) diff --git a/pkg/kvcache/kvevents/process_event_test.go b/pkg/kvcache/kvevents/process_event_test.go new file mode 100644 index 000000000..57223df3b --- /dev/null +++ b/pkg/kvcache/kvevents/process_event_test.go @@ -0,0 +1,103 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kvevents_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/vmihailenco/msgpack/v5" + + . "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvevents" +) + +// Helper function to create BlockStored raw msgpack message. +func createBlockStoredRaw(t *testing.T, fields []any) msgpack.RawMessage { + t.Helper() + data, err := msgpack.Marshal(fields) + if err != nil { + t.Fatalf("Failed to marshal fields: %v", err) + } + return msgpack.RawMessage(data) +} + +func TestBlockStoredMissingLoraName(t *testing.T) { + rawMsg := createBlockStoredRaw(t, []any{ + BlockStoredEventTag, // Event tag + []any{uint64(1001), uint64(1002)}, // BlockHashes + nil, // ParentBlockHash + []uint32{1, 2, 3}, // TokenIds + 256, // BlockSize + 42, // LoraID + "GPU", // Medium + // LoraName is missing + }) + + _, err := UnmarshalKVEvent(rawMsg) + + // Expect error due to missing LoraName + require.Error(t, err) +} + +func TestBlockStoredAllFieldsPresent(t *testing.T) { + rawMsg := createBlockStoredRaw(t, []any{ + BlockStoredEventTag, // Event tag + []any{uint64(1001), uint64(1002)}, // BlockHashes + nil, // ParentBlockHash + []uint32{1, 2, 3}, // TokenIds + 256, // BlockSize + 42, // LoraID + "GPU", // Medium + "test-lora", // LoraName + }) + + event, err := UnmarshalKVEvent(rawMsg) + + require.NoError(t, err, "Expected no error during unmarshaling") + require.NotNil(t, event, "Expected event to be non-nil") + + blockStored, ok := event.(BlockStored) + require.True(t, ok, "Expected event to be of type BlockStored") + + if blockStored.Medium == nil || *blockStored.Medium != "GPU" { + t.Errorf("Expected Medium to be 'GPU', got %v", blockStored.Medium) + } + require.NotNil(t, blockStored.Medium, "Expected Medium to be non-nil") + require.Equal(t, "GPU", *blockStored.Medium, "Expected Medium to be 'GPU'") + + require.NotNil(t, blockStored.LoraName, "Expected LoraName to be non-nil") + require.Equal(t, "test-lora", *blockStored.LoraName, "Expected LoraName to be 'test-lora'") +} + +func TestUnmarshalKVEventErrors(t *testing.T) { + // Test unknown event tag + rawMsg := createBlockStoredRaw(t, []any{ + BlockStoredEventTag, // Event tag + []any{uint64(1001), uint64(1002)}, // BlockHashes + nil, // ParentBlockHash + []uint32{1, 2, 3}, // TokenIds + }) + + var err error + _, err = UnmarshalKVEvent(rawMsg) + require.Error(t, err, "Expected error for incomplete BlockStored event") + + // Test malformed union (empty array) + emptyRawMsg := createBlockStoredRaw(t, []any{}) + _, err = UnmarshalKVEvent(emptyRawMsg) + require.Error(t, err, "Expected error for malformed tagged union") +}