Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ FROM base AS development
ARG GOPROXY
ENV GOPROXY ${GOPROXY}

RUN apk add --update openssl openssh-client util-linux setpriv
RUN apk add --update openssl build-base binutils-gold openssh-client util-linux setpriv
RUN go install github.com/air-verse/air@v1.62 && \
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.3

Expand Down
2 changes: 1 addition & 1 deletion api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ARG GOPROXY
ARG MJML_VERSION
ENV GOPROXY=${GOPROXY}

RUN apk add --update openssl build-base docker-cli npm
RUN apk add --update openssl build-base binutils-gold docker-cli npm
RUN npm install -g mjml@${MJML_VERSION}
RUN go install github.com/air-verse/air@v1.62 && \
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.11.3 && \
Expand Down
5 changes: 5 additions & 0 deletions api/services/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ var DeviceFilterFields = query.NewFieldConstraints(map[string][]string{
"info.platform": {"contains", "eq", "ne"},
"tags.name": {"contains", "eq"},
"online": {"bool", "eq"},
"custom_fields": {"contains"},
})

// DeviceSortFields is the set of field names accepted in the sort_by query
Expand Down Expand Up @@ -380,6 +381,10 @@ func (s *service) UpdateDevice(ctx context.Context, req *requests.DeviceUpdate)
device.Name = strings.ToLower(req.Name)
}

if req.CustomFields != nil {
device.CustomFields = *req.CustomFields
}

if err := s.store.DeviceUpdate(ctx, device); err != nil { // nolint:revive
return err
}
Expand Down
117 changes: 117 additions & 0 deletions api/services/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2411,6 +2411,123 @@ func TestDeviceUpdate(t *testing.T) {
},
expected: nil,
},
{
description: "success when setting custom fields",
req: &requests.DeviceUpdate{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
TenantID: "00000000-0000-0000-0000-000000000000",
Name: "existingname",
CustomFields: &map[string]string{"env": "production", "owner": "team-a"},
},
requiredMocks: func(ctx context.Context) {
device := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
}
updatedDevice := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{"env": "production", "owner": "team-a"},
}
queryOptionsMock.
On("InNamespace", "00000000-0000-0000-0000-000000000000").
Return(nil).
Once()
storeMock.
On("DeviceResolve", ctx, store.DeviceUIDResolver, "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e", mock.AnythingOfType("store.QueryOption")).
Return(device, nil).
Once()
// Distinct clears Name when req.Name == device.Name
storeMock.
On("DeviceConflicts", ctx, &models.DeviceConflicts{Name: ""}).
Return([]string{}, false, nil).
Once()
storeMock.
On("DeviceUpdate", ctx, updatedDevice).
Return(nil).
Once()
},
expected: nil,
},
{
description: "success when clearing custom fields with empty map",
req: &requests.DeviceUpdate{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
TenantID: "00000000-0000-0000-0000-000000000000",
Name: "existingname",
CustomFields: &map[string]string{},
},
requiredMocks: func(ctx context.Context) {
device := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{"env": "production"},
}
updatedDevice := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{},
}
queryOptionsMock.
On("InNamespace", "00000000-0000-0000-0000-000000000000").
Return(nil).
Once()
storeMock.
On("DeviceResolve", ctx, store.DeviceUIDResolver, "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e", mock.AnythingOfType("store.QueryOption")).
Return(device, nil).
Once()
// Distinct clears Name when req.Name == device.Name
storeMock.
On("DeviceConflicts", ctx, &models.DeviceConflicts{Name: ""}).
Return([]string{}, false, nil).
Once()
storeMock.
On("DeviceUpdate", ctx, updatedDevice).
Return(nil).
Once()
},
expected: nil,
},
{
description: "does not modify custom fields when CustomFields is nil",
req: &requests.DeviceUpdate{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
TenantID: "00000000-0000-0000-0000-000000000000",
Name: "existingname",
CustomFields: nil,
},
requiredMocks: func(ctx context.Context) {
device := &models.Device{
UID: "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e",
Name: "existingname",
DisconnectedAt: &now,
CustomFields: map[string]string{"env": "production"},
}
queryOptionsMock.
On("InNamespace", "00000000-0000-0000-0000-000000000000").
Return(nil).
Once()
storeMock.
On("DeviceResolve", ctx, store.DeviceUIDResolver, "d6c6a5e97217bbe4467eae46ab004695a766c5c43f70b95efd4b6a4d32b33c6e", mock.AnythingOfType("store.QueryOption")).
Return(device, nil).
Once()
// Distinct clears Name when req.Name == device.Name
storeMock.
On("DeviceConflicts", ctx, &models.DeviceConflicts{Name: ""}).
Return([]string{}, false, nil).
Once()
// device passed unchanged — CustomFields still has "env":"production"
storeMock.
On("DeviceUpdate", ctx, device).
Return(nil).
Once()
},
expected: nil,
},
}

service := NewService(storeMock, privateKey, publicKey, storecache.NewNullCache(), clientMock)
Expand Down
40 changes: 40 additions & 0 deletions api/store/mongo/internal/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,43 @@ func fromLt(value interface{}) (bson.M, error) {
func fromNe(value interface{}) (bson.M, error) {
return bson.M{"$ne": value}, nil
}

// ParseCustomFieldsFilter builds a MongoDB $match condition that searches across all values
// of the custom_fields document. Only "contains" with a string value is supported.
func ParseCustomFieldsFilter(fp *query.FilterProperty) (bson.M, bool, error) {
if fp.Operator != "contains" {
return nil, false, nil
}

v, ok := fp.Value.(string)
if !ok {
return nil, false, errors.New("custom_fields contains filter requires a string value")
}

// Use $objectToArray to iterate over all values in the custom_fields map,
// then check if any value matches the regex.
condition := bson.M{
"$expr": bson.M{
"$gt": bson.A{
bson.M{
"$size": bson.M{
"$filter": bson.M{
"input": bson.M{"$objectToArray": bson.M{"$ifNull": bson.A{"$custom_fields", bson.M{}}}},
"as": "cf",
"cond": bson.M{
"$regexMatch": bson.M{
"input": "$$cf.v",
"regex": v,
"options": "i",
},
},
},
},
},
0,
},
},
}

return condition, true, nil
}
104 changes: 104 additions & 0 deletions api/store/mongo/internal/filters_custom_fields_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package internal

import (
"testing"

"github.com/shellhub-io/shellhub/pkg/api/query"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"
)

func TestParseCustomFieldsFilter(t *testing.T) {
cases := []struct {
description string
fp *query.FilterProperty
wantOk bool
wantErr bool
checkResult func(t *testing.T, result bson.M)
}{
{
description: "returns not-ok for unsupported operator eq",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "eq", Value: "prod"},
wantOk: false,
wantErr: false,
checkResult: func(t *testing.T, result bson.M) {
assert.Nil(t, result)
},
},
{
description: "returns error when value is not a string",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "contains", Value: 42},
wantOk: false,
wantErr: true,
checkResult: func(t *testing.T, result bson.M) {
assert.Nil(t, result)
},
},
{
description: "returns $expr condition for contains with string value",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "contains", Value: "production"},
wantOk: true,
wantErr: false,
checkResult: func(t *testing.T, result bson.M) {
require.NotNil(t, result)
// Top-level key must be $expr
exprRaw, ok := result["$expr"]
require.True(t, ok, "result must have $expr key")

expr, ok := exprRaw.(bson.M)
require.True(t, ok)

// $expr.$gt must exist
gtRaw, ok := expr["$gt"]
require.True(t, ok, "$expr must have $gt")

gt, ok := gtRaw.(bson.A)
require.True(t, ok)
require.Len(t, gt, 2)

// Second element of $gt must be 0 (threshold)
assert.Equal(t, 0, gt[1])

// First element is the $size expression
sizeExpr, ok := gt[0].(bson.M)
require.True(t, ok)
_, hasSz := sizeExpr["$size"]
assert.True(t, hasSz, "$gt[0] must be a $size expression")
},
},
{
description: "regex contains the search value",
fp: &query.FilterProperty{Name: "custom_fields", Operator: "contains", Value: "team-a"},
wantOk: true,
wantErr: false,
checkResult: func(t *testing.T, result bson.M) {
require.NotNil(t, result)
// Walk down to the $regexMatch input
expr := result["$expr"].(bson.M)
gt := expr["$gt"].(bson.A)
sizeExpr := gt[0].(bson.M)
filterExpr := sizeExpr["$size"].(bson.M)
filterMap := filterExpr["$filter"].(bson.M)
cond := filterMap["cond"].(bson.M)
regexMatch := cond["$regexMatch"].(bson.M)

assert.Equal(t, "team-a", regexMatch["regex"])
assert.Equal(t, "i", regexMatch["options"])
},
},
}

for _, tc := range cases {
t.Run(tc.description, func(t *testing.T) {
result, ok, err := ParseCustomFieldsFilter(tc.fp)
assert.Equal(t, tc.wantOk, ok)
if tc.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
}
tc.checkResult(t, result)
})
}
}
12 changes: 12 additions & 0 deletions api/store/mongo/query-options.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,18 @@ func (*queryOptions) Match(filters *query.Filters) store.QueryOption {
return query.ErrFilterInvalid
}

if param.Name == "custom_fields" {
condition, ok, err := internal.ParseCustomFieldsFilter(param)
switch {
case err != nil:
return query.ErrFilterPropertyInvalid
case ok:
conditions = append(conditions, condition)
}

continue
}

property, ok, err := internal.ParseFilterProperty(param)
switch {
case err != nil:
Expand Down
45 changes: 24 additions & 21 deletions api/store/pg/entity/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,28 @@ import (
type Device struct {
bun.BaseModel `bun:"table:devices"`

ID string `bun:"id,pk"`
NamespaceID string `bun:"namespace_id,type:uuid"`
CreatedAt time.Time `bun:"created_at"`
UpdatedAt time.Time `bun:"updated_at"`
RemovedAt *time.Time `bun:"removed_at"`
LastSeen time.Time `bun:"last_seen"`
DisconnectedAt time.Time `bun:"disconnected_at,nullzero"`
Online bool `bun:",scanonly"`
Acceptable bool `bun:",scanonly"`
Status string `bun:"status"`
StatusUpdatedAt time.Time `bun:"status_updated_at"`
Name string `bun:"name"`
MAC string `bun:"mac"`
PublicKey string `bun:"public_key"`
Identifier string `bun:"identifier"`
PrettyName string `bun:"pretty_name"`
Version string `bun:"version"`
Arch string `bun:"arch"`
Platform string `bun:"platform"`
Longitude float64 `bun:"longitude,type:numeric"`
Latitude float64 `bun:"latitude,type:numeric"`
ID string `bun:"id,pk"`
NamespaceID string `bun:"namespace_id,type:uuid"`
CreatedAt time.Time `bun:"created_at"`
UpdatedAt time.Time `bun:"updated_at"`
RemovedAt *time.Time `bun:"removed_at"`
LastSeen time.Time `bun:"last_seen"`
DisconnectedAt time.Time `bun:"disconnected_at,nullzero"`
Online bool `bun:",scanonly"`
Acceptable bool `bun:",scanonly"`
Status string `bun:"status"`
StatusUpdatedAt time.Time `bun:"status_updated_at"`
Name string `bun:"name"`
MAC string `bun:"mac"`
PublicKey string `bun:"public_key"`
Identifier string `bun:"identifier"`
PrettyName string `bun:"pretty_name"`
Version string `bun:"version"`
Arch string `bun:"arch"`
Platform string `bun:"platform"`
Longitude float64 `bun:"longitude,type:numeric"`
Latitude float64 `bun:"latitude,type:numeric"`
CustomFields map[string]string `bun:"custom_fields,type:jsonb,nullzero,default:'{}'"`

Namespace *Namespace `bun:"rel:belongs-to,join:namespace_id=id"`
Tags []*Tag `bun:"m2m:device_tags,join:Device=Tag"`
Expand All @@ -54,6 +55,7 @@ func DeviceFromModel(model *models.Device) *Device {
StatusUpdatedAt: model.StatusUpdatedAt,
Name: model.Name,
PublicKey: model.PublicKey,
CustomFields: model.CustomFields,
Tags: []*Tag{},
}

Expand Down Expand Up @@ -112,6 +114,7 @@ func DeviceToModel(entity *Device) *models.Device {
Namespace: "",
DisconnectedAt: nil,
RemoteAddr: "",
CustomFields: entity.CustomFields,
Taggable: models.Taggable{
Tags: []models.Tag{},
},
Expand Down
Loading
Loading