diff --git a/service/go.mod b/service/go.mod index 9e3517a720..3bacd59681 100644 --- a/service/go.mod +++ b/service/go.mod @@ -56,6 +56,8 @@ require ( github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/itchyny/gojq v0.12.15 // indirect + github.com/itchyny/timefmt-go v0.1.5 // indirect ) replace ( diff --git a/service/go.sum b/service/go.sum index a96f07f104..17f5317310 100644 --- a/service/go.sum +++ b/service/go.sum @@ -180,6 +180,10 @@ github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/itchyny/gojq v0.12.15 h1:WC1Nxbx4Ifw5U2oQWACYz32JK8G9qxNtHzrvW4KEcqI= +github.com/itchyny/gojq v0.12.15/go.mod h1:uWAHCbCIla1jiNxmeT5/B5mOjSdfkCq6p8vxWg+BM10= +github.com/itchyny/timefmt-go v0.1.5 h1:G0INE2la8S6ru/ZI5JecgyzbbJNs5lG1RcBqa7Jm6GE= +github.com/itchyny/timefmt-go v0.1.5/go.mod h1:nEP7L+2YmAbT2kZ2HfSs1d8Xtw9LY8D2stDBckWakZ8= github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= diff --git a/service/internal/jqbuiltin/README.md b/service/internal/jqbuiltin/README.md new file mode 100644 index 0000000000..528fc17f47 --- /dev/null +++ b/service/internal/jqbuiltin/README.md @@ -0,0 +1,51 @@ +## Testing an OPA builtin with a rego query + +1. Set up your main.go to be the following +``` +func main() { + logLevel := &slog.LevelVar{} + logLevel.Set(slog.LevelDebug) + + opts := &slog.HandlerOptions{ + Level: logLevel, + } + logger := slog.New(slog.NewJSONHandler(os.Stdout, opts)) + + slog.SetDefault(logger) + + jqbuiltin.JQBuiltin() + + if err := cmd.RootCommand.Execute(); err != nil { + fmt.Println(err) + os.Exit(1) + } +} +``` + +2. Build the executable +``` +cd service +go build -o opa++ +``` + +3. Create an example rego file +``` +package sample + +my_json = { + "testing1": { + "testing2": { + "testing3": ["helloworld"] + } + } +} +req = ".testing1.testing2.testing3[]" + +res := jq.evaluate(my_json, req) +``` + +4. Perform the query +``` +./opa++ eval -d example.rego 'data.sample.res' +``` + diff --git a/service/internal/jqbuiltin/jq_builtin.go b/service/internal/jqbuiltin/jq_builtin.go new file mode 100644 index 0000000000..6530758a79 --- /dev/null +++ b/service/internal/jqbuiltin/jq_builtin.go @@ -0,0 +1,102 @@ +package jqbuiltin + +import ( + "bytes" + "encoding/json" + "log/slog" + "strconv" + + "github.com/itchyny/gojq" + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/rego" + "github.com/open-policy-agent/opa/types" +) + +func JQBuiltin() { + rego.RegisterBuiltin2(®o.Function{ + Name: "jq.evaluate", + Decl: types.NewFunction(types.Args(types.A, types.S), types.A), + Memoize: true, + Nondeterministic: true, + }, func(_ rego.BuiltinContext, a, b *ast.Term) (*ast.Term, error) { + slog.Debug("JQ plugin invoked") + var input map[string]any + var query string + + if err := ast.As(a.Value, &input); err != nil { + return nil, err + } else if err := ast.As(b.Value, &query); err != nil { + return nil, err + } + + res, err := ExecuteQuery(input, query) + if err != nil { + return nil, err + } + respBytes, err := json.Marshal(res) + if err != nil { + return nil, err + } + reader := bytes.NewReader(respBytes) + v, err := ast.ValueFromReader(reader) + if err != nil { + return nil, err + } + + return ast.NewTerm(v), nil + }, + ) +} + +func ExecuteQuery(inputJSON map[string]any, queryString string) ([]any, error) { + // first unescape the query string + unescapedQueryString, err := unescapeQueryString(queryString) + if err != nil { + return nil, err + } + + query, err := gojq.Parse(unescapedQueryString) + if err != nil { + return nil, err + } + iter := query.Run(inputJSON) + found := []any{} + for { + v, ok := iter.Next() + if !ok { + break + } + if err, ok2 := v.(error); ok2 { + //nolint:errorlint // temp following gojq example + if err, ok3 := err.(*gojq.HaltError); ok3 && err.Value() == nil { + break + } + // ignore error: we don't have a match but that is not an error state in this case + } else { + if v != nil { + found = append(found, v) + } + } + } + + return found, nil +} + +// unescape any strings within the provided string +func unescapeQueryString(queryString string) (string, error) { + if queryString == "" { + return "", nil + } + unquotedQueryString, err := strconv.Unquote(queryString) + if err != nil { + if err.Error() == "invalid syntax" { + slog.Debug("invalid syntax error when unquoting means there was nothing to unescape. carry on.", slog.String("queryString", queryString)) + unquotedQueryString = queryString + } else { + slog.Error("failed to unescape double quotes in subject external selector value", slog.String("queryString", queryString), slog.String("error", err.Error())) + return "", err + } + } + slog.Debug("unescaped any double quotes in jq query string", slog.String("queryString", unquotedQueryString)) + return unquotedQueryString, nil +} diff --git a/service/internal/jqbuiltin/jq_builtin_test.go b/service/internal/jqbuiltin/jq_builtin_test.go new file mode 100644 index 0000000000..1e64303b02 --- /dev/null +++ b/service/internal/jqbuiltin/jq_builtin_test.go @@ -0,0 +1,79 @@ +package jqbuiltin_test + +import ( + "testing" + + "github.com/opentdf/platform/service/internal/jqbuiltin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testResult1 string = "helloworld" + +var testInput1 = map[string]interface{}{ + "testing1": testResult1, +} +var testQuery1 = ".testing1" + +func Test_JQSuccessSimple(t *testing.T) { + res, err := jqbuiltin.ExecuteQuery(testInput1, testQuery1) + require.NoError(t, err) + assert.Equal(t, []any{testResult1}, res) +} + +var testInput2 = map[string]interface{}{ + "testing1": map[string]interface{}{"testing2": testResult1}, +} +var testQuery2 = ".testing1.testing2" + +func Test_JQSuccessTwoDeep(t *testing.T) { + res, err := jqbuiltin.ExecuteQuery(testInput2, testQuery2) + require.NoError(t, err) + assert.Equal(t, []any{testResult1}, res) +} + +var testInput3 = map[string]interface{}{ + "testing1": map[string]interface{}{"testing2": []any{testResult1}}, +} +var testQuery3 = ".testing1.testing2[0]" + +func Test_JQSuccessTwoDeepInArray(t *testing.T) { + res, err := jqbuiltin.ExecuteQuery(testInput3, testQuery3) + require.NoError(t, err) + assert.Equal(t, []any{testResult1}, res) +} + +const testResult2 string = "whatsup" + +var testInput4 = map[string]interface{}{ + "testing1": map[string]interface{}{"testing2": []any{testResult1, testResult2}}, +} +var testQuery4 = ".testing1.testing2[]" + +func Test_JQSuccessTwoDeepAllInArray(t *testing.T) { + res, err := jqbuiltin.ExecuteQuery(testInput4, testQuery4) + require.NoError(t, err) + assert.Equal(t, []any{testResult1, testResult2}, res) +} + +var testInput5 = map[string]interface{}{ + "testing1": map[string]interface{}{"testing2": testResult1}, +} +var testQuery5 = ".testing1.testing3" + +func Test_JQSuccessTwoDeepAllNoMatch(t *testing.T) { + res, err := jqbuiltin.ExecuteQuery(testInput5, testQuery5) + require.NoError(t, err) + assert.Equal(t, []any{}, res) +} + +var testInput6 = map[string]interface{}{ + "testing1": map[string]interface{}{"testing2": []any{testResult1, testResult2}}, +} +var testQuery6 = ".testing1.testing2 | index(\"" + testResult2 + "\")" + +func Test_JQSuccessUnescapeQuote(t *testing.T) { + res, err := jqbuiltin.ExecuteQuery(testInput6, testQuery6) + require.NoError(t, err) + assert.Equal(t, []any{1}, res) +} diff --git a/service/internal/opa/mock_bundle_server.go b/service/internal/opa/mock_bundle_server.go index 502c7b0ce4..a5ba3ffa82 100644 --- a/service/internal/opa/mock_bundle_server.go +++ b/service/internal/opa/mock_bundle_server.go @@ -14,7 +14,7 @@ type mockBundleServer struct { } func createMockServer() (*mockBundleServer, error) { - policy, err := policies.EntitlementsRego.ReadFile("entitlements/entitlements.rego") + policy, err := policies.EntitlementsRego.ReadFile("entitlements/entitlements-keycloak.rego") if err != nil { return nil, fmt.Errorf("failed to read entitlements policy: %w", err) } diff --git a/service/internal/opa/opa.go b/service/internal/opa/opa.go index da36174e9c..bb76c25074 100644 --- a/service/internal/opa/opa.go +++ b/service/internal/opa/opa.go @@ -11,6 +11,7 @@ import ( opalog "github.com/open-policy-agent/opa/logging" "github.com/open-policy-agent/opa/sdk" "github.com/opentdf/platform/service/internal/idpplugin" + "github.com/opentdf/platform/service/internal/jqbuiltin" ) type Engine struct { @@ -55,6 +56,7 @@ func NewEngine(config Config) (*Engine, error) { } slog.Debug("plugging in plugins") idpplugin.KeycloakBuiltins() + jqbuiltin.JQBuiltin() opa, err := sdk.New(context.Background(), sdk.Options{ Config: bytes.NewReader(bConfig), Logger: &logger, diff --git a/service/policies/entitlements/entitlements-keycloak.rego b/service/policies/entitlements/entitlements-keycloak.rego index 45f21427db..4001ff79ea 100644 --- a/service/policies/entitlements/entitlements-keycloak.rego +++ b/service/policies/entitlements/entitlements-keycloak.rego @@ -10,11 +10,19 @@ import rego.v1 "legacykeycloak": input.idp.legacy, }} +# proto oneof only allows for one of the fields in the entity struct idp_request := {"entities": [{ "id": input.entity.id, - "emailAddress": input.entity.email_address, "clientId": input.entity.client_id, -}]} +}]} if { input.entity.client_id } +else := {"entities": [{ + "id": input.entity.id, + "emailAddress": input.entity.email_address, +}]} if { input.entity.email_address } +else := {"entities": [{ + "id": input.entity.id, + "userName": input.entity.username, +}]} if { input.entity.username } attributes := [attribute | # external entity @@ -34,13 +42,22 @@ condition_group_evaluate(payload, boolean_operator, conditions) if { # AND boolean_operator == 1 some condition in conditions - condition_evaluate(payload[condition.subject_external_selector_value], condition.operator, condition.subject_external_values) + # TODO: additional_props is a list of entity representations + # (for when an email provided is for a group) + # how do we handle the situation when multiple entities returned + # add to the list for each entity? + # or do they all have to have the attribtue for it to be returned? + condition_evaluate(jq.evaluate(payload[0], condition.subject_external_selector_value), + condition.operator, condition.subject_external_values + ) } else if { # OR boolean_operator == 2 payload[key] some condition in conditions - condition_evaluate(payload[condition.subject_external_selector_value], condition.operator, condition.subject_external_values) + condition_evaluate(jq.evaluate(payload[0], condition.subject_external_selector_value), + condition.operator, condition.subject_external_values + ) } # condition diff --git a/service/policies/entitlements/entitlements.rego b/service/policies/entitlements/entitlements.rego index 88b7273f8b..2084cef9a9 100644 --- a/service/policies/entitlements/entitlements.rego +++ b/service/policies/entitlements/entitlements.rego @@ -21,13 +21,13 @@ condition_group_evaluate(payload, boolean_operator, conditions) if { # AND boolean_operator == 1 some condition in conditions - condition_evaluate(payload[condition.subject_external_selector_value], condition.operator, condition.subject_external_values) + condition_evaluate(jq.evaluate(payload, condition.subject_external_selector_value), condition.operator, condition.subject_external_values) } else if { # OR boolean_operator == 2 payload[key] some condition in conditions - condition_evaluate(payload[condition.subject_external_selector_value], condition.operator, condition.subject_external_values) + condition_evaluate(jq.evaluate(payload, condition.subject_external_selector_value), condition.operator, condition.subject_external_values) } # condition