From e21b8491724179b8f49b0a971c7449e77149e62e Mon Sep 17 00:00:00 2001 From: Lorenz Bauer Date: Wed, 3 Feb 2021 17:21:09 +0000 Subject: [PATCH] map: distinguish between keys and values when marshaling Marshaling and unmarshaling map keys and values is quite complicated: * Queues allow nil keys * Some maps allow storing map and program fds * But looking them up returns map and program IDs, not fds * We allow passing unsafe.Pointer for "zero copy" lookups The current implementation isn't careful enough about these constraints. For example, Map implements BinaryMarshaler which encodes the map fd as a host endian uint32. This representation only makes sense when inserting into an ArrayOfMaps or similar. Since we don't distinguish between keys and values, you can use a *Map as a key and get the same behaviour, except that the key ends up being the value of the file descriptor. This doesn't make a lot of sense. Add helpers on Map that should be used when marshaling or unmarshaling keys and values. Re-using the same code everywhere also fixes an inconsistency with LookupAndDelete, which currently doesn't allow unsafe.Pointer. --- map.go | 189 +++++++++++++++++++++++++++++++++----------------- map_test.go | 29 ++------ marshalers.go | 27 ++++++-- prog.go | 7 +- types.go | 17 +++-- 5 files changed, 171 insertions(+), 98 deletions(-) diff --git a/map.go b/map.go index 62ec6243f..59bb2c70b 100644 --- a/map.go +++ b/map.go @@ -420,46 +420,7 @@ func (m *Map) Lookup(key, valueOut interface{}) error { return err } - if valueBytes == nil { - return nil - } - - if m.typ.hasPerCPUValue() { - return unmarshalPerCPUValue(valueOut, int(m.valueSize), valueBytes) - } - - switch value := valueOut.(type) { - case **Map: - m, err := unmarshalMap(valueBytes) - if err != nil { - return err - } - - (*value).Close() - *value = m - return nil - case *Map: - return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Map)(nil)) - case Map: - return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Map)(nil)) - - case **Program: - p, err := unmarshalProgram(valueBytes) - if err != nil { - return err - } - - (*value).Close() - *value = p - return nil - case *Program: - return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Program)(nil)) - case Program: - return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Program)(nil)) - - default: - return unmarshalBytes(valueOut, valueBytes) - } + return m.unmarshalValue(valueOut, valueBytes) } // LookupAndDelete retrieves and deletes a value from a Map. @@ -468,7 +429,7 @@ func (m *Map) Lookup(key, valueOut interface{}) error { func (m *Map) LookupAndDelete(key, valueOut interface{}) error { valuePtr, valueBytes := makeBuffer(valueOut, m.fullValueSize) - keyPtr, err := marshalPtr(key, int(m.keySize)) + keyPtr, err := m.marshalKey(key) if err != nil { return fmt.Errorf("can't marshal key: %w", err) } @@ -477,7 +438,7 @@ func (m *Map) LookupAndDelete(key, valueOut interface{}) error { return fmt.Errorf("lookup and delete failed: %w", err) } - return unmarshalBytes(valueOut, valueBytes) + return m.unmarshalValue(valueOut, valueBytes) } // LookupBytes gets a value from Map. @@ -496,7 +457,7 @@ func (m *Map) LookupBytes(key interface{}) ([]byte, error) { } func (m *Map) lookup(key interface{}, valueOut internal.Pointer) error { - keyPtr, err := marshalPtr(key, int(m.keySize)) + keyPtr, err := m.marshalKey(key) if err != nil { return fmt.Errorf("can't marshal key: %w", err) } @@ -530,17 +491,12 @@ func (m *Map) Put(key, value interface{}) error { // Update changes the value of a key. func (m *Map) Update(key, value interface{}, flags MapUpdateFlags) error { - keyPtr, err := marshalPtr(key, int(m.keySize)) + keyPtr, err := m.marshalKey(key) if err != nil { return fmt.Errorf("can't marshal key: %w", err) } - var valuePtr internal.Pointer - if m.typ.hasPerCPUValue() { - valuePtr, err = marshalPerCPUValue(value, int(m.valueSize)) - } else { - valuePtr, err = marshalPtr(value, int(m.valueSize)) - } + valuePtr, err := m.marshalValue(value) if err != nil { return fmt.Errorf("can't marshal value: %w", err) } @@ -556,7 +512,7 @@ func (m *Map) Update(key, value interface{}, flags MapUpdateFlags) error { // // Returns ErrKeyNotExist if the key does not exist. func (m *Map) Delete(key interface{}) error { - keyPtr, err := marshalPtr(key, int(m.keySize)) + keyPtr, err := m.marshalKey(key) if err != nil { return fmt.Errorf("can't marshal key: %w", err) } @@ -579,11 +535,7 @@ func (m *Map) NextKey(key, nextKeyOut interface{}) error { return err } - if nextKeyBytes == nil { - return nil - } - - if err := unmarshalBytes(nextKeyOut, nextKeyBytes); err != nil { + if err := m.unmarshalKey(nextKeyOut, nextKeyBytes); err != nil { return fmt.Errorf("can't unmarshal next key: %w", err) } return nil @@ -615,7 +567,7 @@ func (m *Map) nextKey(key interface{}, nextKeyOut internal.Pointer) error { ) if key != nil { - keyPtr, err = marshalPtr(key, int(m.keySize)) + keyPtr, err = m.marshalKey(key) if err != nil { return fmt.Errorf("can't marshal key: %w", err) } @@ -720,6 +672,116 @@ func (m *Map) populate(contents []MapKV) error { return nil } +func (m *Map) marshalKey(data interface{}) (internal.Pointer, error) { + if data == nil { + if m.keySize == 0 { + // Queues have a key length of zero, so passing nil here is valid. + return internal.NewPointer(nil), nil + } + return internal.Pointer{}, errors.New("can't use nil as key of map") + } + + return marshalPtr(data, int(m.keySize)) +} + +func (m *Map) unmarshalKey(data interface{}, buf []byte) error { + if buf == nil { + // This is from a makeBuffer call, nothing do do here. + return nil + } + + return unmarshalBytes(data, buf) +} + +func (m *Map) marshalValue(data interface{}) (internal.Pointer, error) { + if m.typ.hasPerCPUValue() { + return marshalPerCPUValue(data, int(m.valueSize)) + } + + var ( + buf []byte + err error + ) + + switch value := data.(type) { + case *Map: + if !m.typ.canStoreMap() { + return internal.Pointer{}, fmt.Errorf("can't store map in %s", m.typ) + } + buf, err = marshalMap(value, int(m.valueSize)) + + case *Program: + if !m.typ.canStoreProgram() { + return internal.Pointer{}, fmt.Errorf("can't store program in %s", m.typ) + } + buf, err = marshalProgram(value, int(m.valueSize)) + + default: + return marshalPtr(data, int(m.valueSize)) + } + + if err != nil { + return internal.Pointer{}, err + } + + return internal.NewSlicePointer(buf), nil +} + +func (m *Map) unmarshalValue(value interface{}, buf []byte) error { + if buf == nil { + // This is from a makeBuffer call, nothing do do here. + return nil + } + + if m.typ.hasPerCPUValue() { + return unmarshalPerCPUValue(value, int(m.valueSize), buf) + } + + switch value := value.(type) { + case **Map: + if !m.typ.canStoreMap() { + return fmt.Errorf("can't read a map from %s", m.typ) + } + + other, err := unmarshalMap(buf) + if err != nil { + return err + } + + (*value).Close() + *value = other + return nil + + case *Map: + if !m.typ.canStoreMap() { + return fmt.Errorf("can't read a map from %s", m.typ) + } + return errors.New("require pointer to *Map") + + case **Program: + if !m.typ.canStoreProgram() { + return fmt.Errorf("can't read a program from %s", m.typ) + } + + other, err := unmarshalProgram(buf) + if err != nil { + return err + } + + (*value).Close() + *value = other + return nil + + case *Program: + if !m.typ.canStoreProgram() { + return fmt.Errorf("can't read a program from %s", m.typ) + } + return errors.New("require pointer to *Program") + } + + return unmarshalBytes(value, buf) +} + // LoadPinnedMap load a Map from a BPF file. func LoadPinnedMap(fileName string) (*Map, error) { fd, err := internal.BPFObjGet(fileName) @@ -730,19 +792,22 @@ func LoadPinnedMap(fileName string) (*Map, error) { return newMapFromFD(fd) } +// unmarshalMap creates a map from a map ID encoded in host endianness. func unmarshalMap(buf []byte) (*Map, error) { if len(buf) != 4 { return nil, errors.New("map id requires 4 byte value") } - // Looking up an entry in a nested map or prog array returns an id, - // not an fd. id := internal.NativeEndian.Uint32(buf) return NewMapFromID(MapID(id)) } -// MarshalBinary implements BinaryMarshaler. -func (m *Map) MarshalBinary() ([]byte, error) { +// marshalMap marshals the fd of a map into a buffer in host endianness. +func marshalMap(m *Map, length int) ([]byte, error) { + if length != 4 { + return nil, fmt.Errorf("can't marshal map to %d bytes", length) + } + fd, err := m.fd.Value() if err != nil { return nil, err @@ -877,7 +942,7 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { return false } - mi.err = unmarshalBytes(keyOut, nextBytes) + mi.err = mi.target.unmarshalKey(keyOut, nextBytes) return mi.err == nil } diff --git a/map_test.go b/map_test.go index 1fe1d2e73..893616b5c 100644 --- a/map_test.go +++ b/map_test.go @@ -196,23 +196,13 @@ func TestMapQueue(t *testing.T) { } defer m.Close() - if err := m.Put(nil, uint32(42)); err != nil { - t.Fatal("Can't put 42:", err) - } - - if err := m.Put(nil, uint32(4242)); err != nil { - t.Fatal("Can't put 4242:", err) + for _, v := range []uint32{42, 4242} { + if err := m.Put(nil, v); err != nil { + t.Fatalf("Can't put %d: %s", v, err) + } } var v uint32 - if err := m.Lookup(nil, &v); err != nil { - t.Fatal("Can't lookup element:", err) - } - if v != 42 { - t.Error("Want value 42, got", v) - } - - v = 0 if err := m.LookupAndDelete(nil, &v); err != nil { t.Fatal("Can't lookup and delete element:", err) } @@ -220,16 +210,9 @@ func TestMapQueue(t *testing.T) { t.Error("Want value 42, got", v) } - if err := m.Lookup(nil, &v); err != nil { - t.Fatal("Can't lookup element:", err) - } - if v != 4242 { - t.Error("Want value 4242, got", v) - } - v = 0 - if err := m.LookupAndDelete(nil, &v); err != nil { - t.Fatal("Can't lookup and delete element:", err) + if err := m.LookupAndDelete(nil, unsafe.Pointer(&v)); err != nil { + t.Fatal("Can't lookup and delete element using unsafe.Pointer:", err) } if v != 4242 { t.Error("Want value 4242, got", v) diff --git a/marshalers.go b/marshalers.go index c0db2f6b0..3ea1021a8 100644 --- a/marshalers.go +++ b/marshalers.go @@ -13,14 +13,12 @@ import ( "github.com/cilium/ebpf/internal" ) +// marshalPtr converts an arbitrary value into a pointer suitable +// to be passed to the kernel. +// +// As an optimization, it returns the original value if it is an +// unsafe.Pointer. func marshalPtr(data interface{}, length int) (internal.Pointer, error) { - if data == nil { - if length == 0 { - return internal.NewPointer(nil), nil - } - return internal.Pointer{}, errors.New("can't use nil as key of map") - } - if ptr, ok := data.(unsafe.Pointer); ok { return internal.NewPointer(ptr), nil } @@ -33,6 +31,13 @@ func marshalPtr(data interface{}, length int) (internal.Pointer, error) { return internal.NewSlicePointer(buf), nil } +// marshalBytes converts an arbitrary value into a byte buffer. +// +// Prefer using Map.marshalKey and Map.marshalValue if possible, since +// those have special cases that allow more types to be encoded. +// +// Returns an error if the given value isn't representable in exactly +// length bytes. func marshalBytes(data interface{}, length int) (buf []byte, err error) { switch value := data.(type) { case encoding.BinaryMarshaler: @@ -43,6 +48,8 @@ func marshalBytes(data interface{}, length int) (buf []byte, err error) { buf = value case unsafe.Pointer: err = errors.New("can't marshal from unsafe.Pointer") + case Map, *Map, Program, *Program: + err = fmt.Errorf("can't marshal %T", value) default: var wr bytes.Buffer err = binary.Write(&wr, internal.NativeEndian, value) @@ -70,6 +77,10 @@ func makeBuffer(dst interface{}, length int) (internal.Pointer, []byte) { return internal.NewSlicePointer(buf), buf } +// unmarshalBytes converts a byte buffer into an arbitrary value. +// +// Prefer using Map.unmarshalKey and Map.unmarshalValue if possible, since +// those have special cases that allow more types to be encoded. func unmarshalBytes(data interface{}, buf []byte) error { switch value := data.(type) { case unsafe.Pointer: @@ -83,6 +94,8 @@ func unmarshalBytes(data interface{}, buf []byte) error { copy(dst, buf) runtime.KeepAlive(value) return nil + case Map, *Map, Program, *Program: + return fmt.Errorf("can't unmarshal into %T", value) case encoding.BinaryUnmarshaler: return value.UnmarshalBinary(buf) case *string: diff --git a/prog.go b/prog.go index e7ebe6fb8..a1cb1ab49 100644 --- a/prog.go +++ b/prog.go @@ -496,8 +496,11 @@ func unmarshalProgram(buf []byte) (*Program, error) { return NewProgramFromID(ProgramID(id)) } -// MarshalBinary implements BinaryMarshaler. -func (p *Program) MarshalBinary() ([]byte, error) { +func marshalProgram(p *Program, length int) ([]byte, error) { + if length != 4 { + return nil, fmt.Errorf("can't marshal program to %d bytes", length) + } + value, err := p.fd.Value() if err != nil { return nil, err diff --git a/types.go b/types.go index dd82dfd4d..53e27ee8c 100644 --- a/types.go +++ b/types.go @@ -85,10 +85,19 @@ const ( // hasPerCPUValue returns true if the Map stores a value per CPU. func (mt MapType) hasPerCPUValue() bool { - if mt == PerCPUHash || mt == PerCPUArray || mt == LRUCPUHash { - return true - } - return false + return mt == PerCPUHash || mt == PerCPUArray || mt == LRUCPUHash +} + +// canStoreMap returns true if the map type accepts a map fd +// for update and returns a map id for lookup. +func (mt MapType) canStoreMap() bool { + return mt == ArrayOfMaps || mt == HashOfMaps +} + +// canStoreProgram returns true if the map type accepts a program fd +// for update and returns a program id for lookup. +func (mt MapType) canStoreProgram() bool { + return mt == ProgramArray } // ProgramType of the eBPF program