Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ FROM base AS development
ARG GOPROXY
ENV GOPROXY ${GOPROXY}

RUN apk add --update openssl openssh-client util-linux setpriv
RUN apk add --update openssl build-base binutils-gold openssh-client util-linux setpriv
RUN go install github.com/air-verse/air@v1.62 && \
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.3

Expand Down
2 changes: 1 addition & 1 deletion api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ARG GOPROXY
ARG MJML_VERSION
ENV GOPROXY=${GOPROXY}

RUN apk add --update openssl build-base docker-cli npm
RUN apk add --update openssl build-base binutils-gold docker-cli npm
RUN npm install -g mjml@${MJML_VERSION}
RUN go install github.com/air-verse/air@v1.62 && \
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.3 && \
Expand Down
5 changes: 5 additions & 0 deletions api/services/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var DeviceFilterFields = query.NewFieldConstraints(map[string][]string{
"info.platform": {"contains", "eq", "ne"},
"tags.name": {"contains", "eq"},
"online": {"bool", "eq"},
"custom_fields": {"contains"},
})

// DeviceSortFields is the set of field names accepted in the sort_by query
Expand Down Expand Up @@ -380,6 +381,10 @@ func (s *service) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate)
device.Name = strings.ToLower(req.Name)
}

if req.CustomFields != nil {
device.CustomFields = *req.CustomFields
}

if err := s.store.DeviceUpdate(ctx, device); err != nil { // nolint:revive
return err
}
Expand Down
117 changes: 117 additions & 0 deletions api/services/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2411,6 +2411,123 @@ func TestDeviceUpdate(t *testing.T) {
},
expected: nil,
},
{
description: "success when setting custom fields",
req: &requests.DeviceUpdate{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
TenantID: "00000000-0000-0000-0000-000000000000",
Name: "existingname",
CustomFields: &map[string]string{"env": "production", "owner": "team-a"},
},
requiredMocks: func(ctx context.Context) {
device := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
}
updatedDevice := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{"env": "production", "owner": "team-a"},
}
queryOptionsMock.
On("InNamespace", "00000000-0000-0000-0000-000000000000").
Return(nil).
Once()
storeMock.
On("DeviceResolve", ctx, store.DeviceUIDResolver, "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e", mock.AnythingOfType("store.QueryOption")).
Return(device, nil).
Once()
// Distinct clears Name when req.Name == device.Name
storeMock.
On("DeviceConflicts", ctx, &models.DeviceConflicts{Name: ""}).
Return([]string{}, false, nil).
Once()
storeMock.
On("DeviceUpdate", ctx, updatedDevice).
Return(nil).
Once()
},
expected: nil,
},
{
description: "success when clearing custom fields with empty map",
req: &requests.DeviceUpdate{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
TenantID: "00000000-0000-0000-0000-000000000000",
Name: "existingname",
CustomFields: &map[string]string{},
},
requiredMocks: func(ctx context.Context) {
device := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{"env": "production"},
}
updatedDevice := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{},
}
queryOptionsMock.
On("InNamespace", "00000000-0000-0000-0000-000000000000").
Return(nil).
Once()
storeMock.
On("DeviceResolve", ctx, store.DeviceUIDResolver, "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e", mock.AnythingOfType("store.QueryOption")).
Return(device, nil).
Once()
// Distinct clears Name when req.Name == device.Name
storeMock.
On("DeviceConflicts", ctx, &models.DeviceConflicts{Name: ""}).
Return([]string{}, false, nil).
Once()
storeMock.
On("DeviceUpdate", ctx, updatedDevice).
Return(nil).
Once()
},
expected: nil,
},
{
description: "does not modify custom fields when CustomFields is nil",
req: &requests.DeviceUpdate{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
TenantID: "00000000-0000-0000-0000-000000000000",
Name: "existingname",
CustomFields: nil,
},
requiredMocks: func(ctx context.Context) {
device := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{"env": "production"},
}
queryOptionsMock.
On("InNamespace", "00000000-0000-0000-0000-000000000000").
Return(nil).
Once()
storeMock.
On("DeviceResolve", ctx, store.DeviceUIDResolver, "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e", mock.AnythingOfType("store.QueryOption")).
Return(device, nil).
Once()
// Distinct clears Name when req.Name == device.Name
storeMock.
On("DeviceConflicts", ctx, &models.DeviceConflicts{Name: ""}).
Return([]string{}, false, nil).
Once()
// device passed unchanged — CustomFields still has "env":"production"
storeMock.
On("DeviceUpdate", ctx, device).
Return(nil).
Once()
},
expected: nil,
},
}

service := NewService(storeMock, privateKey, publicKey, storecache.NewNullCache(), clientMock)
Expand Down
40 changes: 40 additions & 0 deletions api/store/mongo/internal/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,43 @@ func fromLt(value interface{}) (bson.M, error) {
func fromNe(value interface{}) (bson.M, error) {
return bson.M{"$ne": value}, nil
}

// ParseCustomFieldsFilter builds a MongoDB $match condition that searches across all values
// of the custom_fields document. Only "contains" with a string value is supported.
func ParseCustomFieldsFilter(fp *query.FilterProperty) (bson.M, bool, error) {
if fp.Operator != "contains" {
return nil, false, nil
}

v, ok := fp.Value.(string)
if !ok {
return nil, false, errors.New("custom_fields contains filter requires a string value")
}

// Use $objectToArray to iterate over all values in the custom_fields map,
// then check if any value matches the regex.
condition := bson.M{
"$expr": bson.M{
"$gt": bson.A{
bson.M{
"$size": bson.M{
"$filter": bson.M{
"input": bson.M{"$objectToArray": bson.M{"$ifNull": bson.A{"$custom_fields", bson.M{}}}},
"as": "cf",
"cond": bson.M{
"$regexMatch": bson.M{
"input": "$$cf.v",
"regex": v,
"options": "i",
},
},
},
},
},
0,
},
},
}

return condition, true, nil
}
104 changes: 104 additions & 0 deletions api/store/mongo/internal/filters_custom_fields_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package internal

import (
"testing"

"github.com/shellhub-io/shellhub/pkg/api/query"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
)

func TestParseCustomFieldsFilter(t *testing.T) {
cases := []struct {
description string
fp *query.FilterProperty
wantOk bool
wantErr bool
checkResult func(t *testing.T, result bson.M)
}{
{
description: "returns not-ok for unsupported operator eq",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "eq", Value: "prod"},
wantOk: false,
wantErr: false,
checkResult: func(t *testing.T, result bson.M) {
assert.Nil(t, result)
},
},
{
description: "returns error when value is not a string",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "contains", Value: 42},
wantOk: false,
wantErr: true,
checkResult: func(t *testing.T, result bson.M) {
assert.Nil(t, result)
},
},
{
description: "returns $expr condition for contains with string value",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "contains", Value: "production"},
wantOk: true,
wantErr: false,
checkResult: func(t *testing.T, result bson.M) {
require.NotNil(t, result)
// Top-level key must be $expr
exprRaw, ok := result["$expr"]
require.True(t, ok, "result must have $expr key")

expr, ok := exprRaw.(bson.M)
require.True(t, ok)

// $expr.$gt must exist
gtRaw, ok := expr["$gt"]
require.True(t, ok, "$expr must have $gt")

gt, ok := gtRaw.(bson.A)
require.True(t, ok)
require.Len(t, gt, 2)

// Second element of $gt must be 0 (threshold)
assert.Equal(t, 0, gt[1])

// First element is the $size expression
sizeExpr, ok := gt[0].(bson.M)
require.True(t, ok)
_, hasSz := sizeExpr["$size"]
assert.True(t, hasSz, "$gt[0] must be a $size expression")
},
},
{
description: "regex contains the search value",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "contains", Value: "team-a"},
wantOk: true,
wantErr: false,
checkResult: func(t *testing.T, result bson.M) {
require.NotNil(t, result)
// Walk down to the $regexMatch input
expr := result["$expr"].(bson.M)
gt := expr["$gt"].(bson.A)
sizeExpr := gt[0].(bson.M)
filterExpr := sizeExpr["$size"].(bson.M)
filterMap := filterExpr["$filter"].(bson.M)
cond := filterMap["cond"].(bson.M)
regexMatch := cond["$regexMatch"].(bson.M)

assert.Equal(t, "team-a", regexMatch["regex"])
assert.Equal(t, "i", regexMatch["options"])
},
},
}

for _, tc := range cases {
t.Run(tc.description, func(t *testing.T) {
result, ok, err := ParseCustomFieldsFilter(tc.fp)
assert.Equal(t, tc.wantOk, ok)
if tc.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
tc.checkResult(t, result)
})
}
}
12 changes: 12 additions & 0 deletions api/store/mongo/query-options.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ func (*queryOptions) Match(filters *query.Filters) store.QueryOption {
return query.ErrFilterInvalid
}

if param.Name == "custom_fields" {
condition, ok, err := internal.ParseCustomFieldsFilter(param)
switch {
case err != nil:
return query.ErrFilterPropertyInvalid
case ok:
conditions = append(conditions, condition)
}

continue
}

property, ok, err := internal.ParseFilterProperty(param)
switch {
case err != nil:
Expand Down
45 changes: 24 additions & 21 deletions api/store/pg/entity/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,28 @@ import (
type Device struct {
bun.BaseModel `bun:"table:devices"`

ID string `bun:"id,pk"`
NamespaceID string `bun:"namespace_id,type:uuid"`
CreatedAt time.Time `bun:"created_at"`
UpdatedAt time.Time `bun:"updated_at"`
RemovedAt *time.Time `bun:"removed_at"`
LastSeen time.Time `bun:"last_seen"`
DisconnectedAt time.Time `bun:"disconnected_at,nullzero"`
Online bool `bun:",scanonly"`
Acceptable bool `bun:",scanonly"`
Status string `bun:"status"`
StatusUpdatedAt time.Time `bun:"status_updated_at"`
Name string `bun:"name"`
MAC string `bun:"mac"`
PublicKey string `bun:"public_key"`
Identifier string `bun:"identifier"`
PrettyName string `bun:"pretty_name"`
Version string `bun:"version"`
Arch string `bun:"arch"`
Platform string `bun:"platform"`
Longitude float64 `bun:"longitude,type:numeric"`
Latitude float64 `bun:"latitude,type:numeric"`
ID string `bun:"id,pk"`
NamespaceID string `bun:"namespace_id,type:uuid"`
CreatedAt time.Time `bun:"created_at"`
UpdatedAt time.Time `bun:"updated_at"`
RemovedAt *time.Time `bun:"removed_at"`
LastSeen time.Time `bun:"last_seen"`
DisconnectedAt time.Time `bun:"disconnected_at,nullzero"`
Online bool `bun:",scanonly"`
Acceptable bool `bun:",scanonly"`
Status string `bun:"status"`
StatusUpdatedAt time.Time `bun:"status_updated_at"`
Name string `bun:"name"`
MAC string `bun:"mac"`
PublicKey string `bun:"public_key"`
Identifier string `bun:"identifier"`
PrettyName string `bun:"pretty_name"`
Version string `bun:"version"`
Arch string `bun:"arch"`
Platform string `bun:"platform"`
Longitude float64 `bun:"longitude,type:numeric"`
Latitude float64 `bun:"latitude,type:numeric"`
CustomFields map[string]string `bun:"custom_fields,type:jsonb"`
Comment thread
danielgatis marked this conversation as resolved.
Outdated

Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"`
Tags []*Tag `bun:"m2m:device_tags,join:Device=Tag"`
Expand All @@ -54,6 +55,7 @@ func DeviceFromModel(model *models.Device) *Device {
StatusUpdatedAt: model.StatusUpdatedAt,
Name: model.Name,
PublicKey: model.PublicKey,
CustomFields: model.CustomFields,
Tags: []*Tag{},
}

Expand Down Expand Up @@ -112,6 +114,7 @@ func DeviceToModel(entity *Device) *models.Device {
Namespace: "",
DisconnectedAt: nil,
RemoteAddr: "",
CustomFields: entity.CustomFields,
Taggable: models.Taggable{
Tags: []models.Tag{},
},
Expand Down
Loading
Loading