Skip to content

Commit

Permalink
map: distinguish between keys and values when marshaling
Browse files Browse the repository at this point in the history
Marshaling and unmarshaling map keys and values is quite complicated:

* Queues allow nil keys
* Some maps allow storing map and program fds
* But looking them up returns map and program IDs, not fds
* We allow passing unsafe.Pointer for "zero copy" lookups

The current implementation isn't careful enough about these
constraints. For example, Map implements BinaryMarshaler which
encodes the map fd as a host endian uint32. This representation
only makes sense when inserting into an ArrayOfMaps or similar.
Since we don't distinguish between keys and values, you can use
a *Map as a key and get the same behaviour, except that the
key ends up being the value of the file descriptor. This doesn't
make a lot of sense.

Add helpers on Map that should be used when marshaling or unmarshaling
keys and values. Re-using the same code everywhere also fixes
an inconsistency with LookupAndDelete, which currently doesn't
allow unsafe.Pointer.
  • Loading branch information
lmb authored and nathanjsweet committed Feb 3, 2021
1 parent 0e88940 commit e21b849
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 98 deletions.
189 changes: 127 additions & 62 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,46 +420,7 @@ func (m *Map) Lookup(key, valueOut interface{}) error {
return err
}

if valueBytes == nil {
return nil
}

if m.typ.hasPerCPUValue() {
return unmarshalPerCPUValue(valueOut, int(m.valueSize), valueBytes)
}

switch value := valueOut.(type) {
case **Map:
m, err := unmarshalMap(valueBytes)
if err != nil {
return err
}

(*value).Close()
*value = m
return nil
case *Map:
return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Map)(nil))
case Map:
return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Map)(nil))

case **Program:
p, err := unmarshalProgram(valueBytes)
if err != nil {
return err
}

(*value).Close()
*value = p
return nil
case *Program:
return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Program)(nil))
case Program:
return fmt.Errorf("can't unmarshal into %T, need %T", value, (**Program)(nil))

default:
return unmarshalBytes(valueOut, valueBytes)
}
return m.unmarshalValue(valueOut, valueBytes)
}

// LookupAndDelete retrieves and deletes a value from a Map.
Expand All @@ -468,7 +429,7 @@ func (m *Map) Lookup(key, valueOut interface{}) error {
func (m *Map) LookupAndDelete(key, valueOut interface{}) error {
valuePtr, valueBytes := makeBuffer(valueOut, m.fullValueSize)

keyPtr, err := marshalPtr(key, int(m.keySize))
keyPtr, err := m.marshalKey(key)
if err != nil {
return fmt.Errorf("can't marshal key: %w", err)
}
Expand All @@ -477,7 +438,7 @@ func (m *Map) LookupAndDelete(key, valueOut interface{}) error {
return fmt.Errorf("lookup and delete failed: %w", err)
}

return unmarshalBytes(valueOut, valueBytes)
return m.unmarshalValue(valueOut, valueBytes)
}

// LookupBytes gets a value from Map.
Expand All @@ -496,7 +457,7 @@ func (m *Map) LookupBytes(key interface{}) ([]byte, error) {
}

func (m *Map) lookup(key interface{}, valueOut internal.Pointer) error {
keyPtr, err := marshalPtr(key, int(m.keySize))
keyPtr, err := m.marshalKey(key)
if err != nil {
return fmt.Errorf("can't marshal key: %w", err)
}
Expand Down Expand Up @@ -530,17 +491,12 @@ func (m *Map) Put(key, value interface{}) error {

// Update changes the value of a key.
func (m *Map) Update(key, value interface{}, flags MapUpdateFlags) error {
keyPtr, err := marshalPtr(key, int(m.keySize))
keyPtr, err := m.marshalKey(key)
if err != nil {
return fmt.Errorf("can't marshal key: %w", err)
}

var valuePtr internal.Pointer
if m.typ.hasPerCPUValue() {
valuePtr, err = marshalPerCPUValue(value, int(m.valueSize))
} else {
valuePtr, err = marshalPtr(value, int(m.valueSize))
}
valuePtr, err := m.marshalValue(value)
if err != nil {
return fmt.Errorf("can't marshal value: %w", err)
}
Expand All @@ -556,7 +512,7 @@ func (m *Map) Update(key, value interface{}, flags MapUpdateFlags) error {
//
// Returns ErrKeyNotExist if the key does not exist.
func (m *Map) Delete(key interface{}) error {
keyPtr, err := marshalPtr(key, int(m.keySize))
keyPtr, err := m.marshalKey(key)
if err != nil {
return fmt.Errorf("can't marshal key: %w", err)
}
Expand All @@ -579,11 +535,7 @@ func (m *Map) NextKey(key, nextKeyOut interface{}) error {
return err
}

if nextKeyBytes == nil {
return nil
}

if err := unmarshalBytes(nextKeyOut, nextKeyBytes); err != nil {
if err := m.unmarshalKey(nextKeyOut, nextKeyBytes); err != nil {
return fmt.Errorf("can't unmarshal next key: %w", err)
}
return nil
Expand Down Expand Up @@ -615,7 +567,7 @@ func (m *Map) nextKey(key interface{}, nextKeyOut internal.Pointer) error {
)

if key != nil {
keyPtr, err = marshalPtr(key, int(m.keySize))
keyPtr, err = m.marshalKey(key)
if err != nil {
return fmt.Errorf("can't marshal key: %w", err)
}
Expand Down Expand Up @@ -720,6 +672,116 @@ func (m *Map) populate(contents []MapKV) error {
return nil
}

func (m *Map) marshalKey(data interface{}) (internal.Pointer, error) {
if data == nil {
if m.keySize == 0 {
// Queues have a key length of zero, so passing nil here is valid.
return internal.NewPointer(nil), nil
}
return internal.Pointer{}, errors.New("can't use nil as key of map")
}

return marshalPtr(data, int(m.keySize))
}

func (m *Map) unmarshalKey(data interface{}, buf []byte) error {
if buf == nil {
// This is from a makeBuffer call, nothing do do here.
return nil
}

return unmarshalBytes(data, buf)
}

func (m *Map) marshalValue(data interface{}) (internal.Pointer, error) {
if m.typ.hasPerCPUValue() {
return marshalPerCPUValue(data, int(m.valueSize))
}

var (
buf []byte
err error
)

switch value := data.(type) {
case *Map:
if !m.typ.canStoreMap() {
return internal.Pointer{}, fmt.Errorf("can't store map in %s", m.typ)
}
buf, err = marshalMap(value, int(m.valueSize))

case *Program:
if !m.typ.canStoreProgram() {
return internal.Pointer{}, fmt.Errorf("can't store program in %s", m.typ)
}
buf, err = marshalProgram(value, int(m.valueSize))

default:
return marshalPtr(data, int(m.valueSize))
}

if err != nil {
return internal.Pointer{}, err
}

return internal.NewSlicePointer(buf), nil
}

func (m *Map) unmarshalValue(value interface{}, buf []byte) error {
if buf == nil {
// This is from a makeBuffer call, nothing do do here.
return nil
}

if m.typ.hasPerCPUValue() {
return unmarshalPerCPUValue(value, int(m.valueSize), buf)
}

switch value := value.(type) {
case **Map:
if !m.typ.canStoreMap() {
return fmt.Errorf("can't read a map from %s", m.typ)
}

other, err := unmarshalMap(buf)
if err != nil {
return err
}

(*value).Close()
*value = other
return nil

case *Map:
if !m.typ.canStoreMap() {
return fmt.Errorf("can't read a map from %s", m.typ)
}
return errors.New("require pointer to *Map")

case **Program:
if !m.typ.canStoreProgram() {
return fmt.Errorf("can't read a program from %s", m.typ)
}

other, err := unmarshalProgram(buf)
if err != nil {
return err
}

(*value).Close()
*value = other
return nil

case *Program:
if !m.typ.canStoreProgram() {
return fmt.Errorf("can't read a program from %s", m.typ)
}
return errors.New("require pointer to *Program")
}

return unmarshalBytes(value, buf)
}

// LoadPinnedMap load a Map from a BPF file.
func LoadPinnedMap(fileName string) (*Map, error) {
fd, err := internal.BPFObjGet(fileName)
Expand All @@ -730,19 +792,22 @@ func LoadPinnedMap(fileName string) (*Map, error) {
return newMapFromFD(fd)
}

// unmarshalMap creates a map from a map ID encoded in host endianness.
func unmarshalMap(buf []byte) (*Map, error) {
if len(buf) != 4 {
return nil, errors.New("map id requires 4 byte value")
}

// Looking up an entry in a nested map or prog array returns an id,
// not an fd.
id := internal.NativeEndian.Uint32(buf)
return NewMapFromID(MapID(id))
}

// MarshalBinary implements BinaryMarshaler.
func (m *Map) MarshalBinary() ([]byte, error) {
// marshalMap marshals the fd of a map into a buffer in host endianness.
func marshalMap(m *Map, length int) ([]byte, error) {
if length != 4 {
return nil, fmt.Errorf("can't marshal map to %d bytes", length)
}

fd, err := m.fd.Value()
if err != nil {
return nil, err
Expand Down Expand Up @@ -877,7 +942,7 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool {
return false
}

mi.err = unmarshalBytes(keyOut, nextBytes)
mi.err = mi.target.unmarshalKey(keyOut, nextBytes)
return mi.err == nil
}

Expand Down
29 changes: 6 additions & 23 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,40 +196,23 @@ func TestMapQueue(t *testing.T) {
}
defer m.Close()

if err := m.Put(nil, uint32(42)); err != nil {
t.Fatal("Can't put 42:", err)
}

if err := m.Put(nil, uint32(4242)); err != nil {
t.Fatal("Can't put 4242:", err)
for _, v := range []uint32{42, 4242} {
if err := m.Put(nil, v); err != nil {
t.Fatalf("Can't put %d: %s", v, err)
}
}

var v uint32
if err := m.Lookup(nil, &v); err != nil {
t.Fatal("Can't lookup element:", err)
}
if v != 42 {
t.Error("Want value 42, got", v)
}

v = 0
if err := m.LookupAndDelete(nil, &v); err != nil {
t.Fatal("Can't lookup and delete element:", err)
}
if v != 42 {
t.Error("Want value 42, got", v)
}

if err := m.Lookup(nil, &v); err != nil {
t.Fatal("Can't lookup element:", err)
}
if v != 4242 {
t.Error("Want value 4242, got", v)
}

v = 0
if err := m.LookupAndDelete(nil, &v); err != nil {
t.Fatal("Can't lookup and delete element:", err)
if err := m.LookupAndDelete(nil, unsafe.Pointer(&v)); err != nil {
t.Fatal("Can't lookup and delete element using unsafe.Pointer:", err)
}
if v != 4242 {
t.Error("Want value 4242, got", v)
Expand Down
27 changes: 20 additions & 7 deletions marshalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@ import (
"github.com/cilium/ebpf/internal"
)

// marshalPtr converts an arbitrary value into a pointer suitable
// to be passed to the kernel.
//
// As an optimization, it returns the original value if it is an
// unsafe.Pointer.
func marshalPtr(data interface{}, length int) (internal.Pointer, error) {
if data == nil {
if length == 0 {
return internal.NewPointer(nil), nil
}
return internal.Pointer{}, errors.New("can't use nil as key of map")
}

if ptr, ok := data.(unsafe.Pointer); ok {
return internal.NewPointer(ptr), nil
}
Expand All @@ -33,6 +31,13 @@ func marshalPtr(data interface{}, length int) (internal.Pointer, error) {
return internal.NewSlicePointer(buf), nil
}

// marshalBytes converts an arbitrary value into a byte buffer.
//
// Prefer using Map.marshalKey and Map.marshalValue if possible, since
// those have special cases that allow more types to be encoded.
//
// Returns an error if the given value isn't representable in exactly
// length bytes.
func marshalBytes(data interface{}, length int) (buf []byte, err error) {
switch value := data.(type) {
case encoding.BinaryMarshaler:
Expand All @@ -43,6 +48,8 @@ func marshalBytes(data interface{}, length int) (buf []byte, err error) {
buf = value
case unsafe.Pointer:
err = errors.New("can't marshal from unsafe.Pointer")
case Map, *Map, Program, *Program:
err = fmt.Errorf("can't marshal %T", value)
default:
var wr bytes.Buffer
err = binary.Write(&wr, internal.NativeEndian, value)
Expand Down Expand Up @@ -70,6 +77,10 @@ func makeBuffer(dst interface{}, length int) (internal.Pointer, []byte) {
return internal.NewSlicePointer(buf), buf
}

// unmarshalBytes converts a byte buffer into an arbitrary value.
//
// Prefer using Map.unmarshalKey and Map.unmarshalValue if possible, since
// those have special cases that allow more types to be encoded.
func unmarshalBytes(data interface{}, buf []byte) error {
switch value := data.(type) {
case unsafe.Pointer:
Expand All @@ -83,6 +94,8 @@ func unmarshalBytes(data interface{}, buf []byte) error {
copy(dst, buf)
runtime.KeepAlive(value)
return nil
case Map, *Map, Program, *Program:
return fmt.Errorf("can't unmarshal into %T", value)
case encoding.BinaryUnmarshaler:
return value.UnmarshalBinary(buf)
case *string:
Expand Down
Loading

0 comments on commit e21b849

Please sign in to comment.