Skip to content

Commit 144685a

Browse files
fix(Query): Fix val queries when ACL is enabled #5945
1 parent d27ca2d commit 144685a

File tree

2 files changed

+281
-6
lines changed

2 files changed

+281
-6
lines changed

edgraph/access_ee.go

+78-6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ import (
4040
"google.golang.org/grpc/status"
4141
)
4242

43+
type predsAndvars struct {
44+
preds []string
45+
vars map[string]string
46+
}
47+
4348
// Login handles login requests from clients.
4449
func (s *Server) Login(ctx context.Context,
4550
request *api.LoginRequest) (*api.Response, error) {
@@ -682,14 +687,19 @@ func authorizeMutation(ctx context.Context, gmu *gql.Mutation) error {
682687
return err
683688
}
684689

685-
func parsePredsFromQuery(gqls []*gql.GraphQuery) []string {
690+
func parsePredsFromQuery(gqls []*gql.GraphQuery) predsAndvars {
686691
predsMap := make(map[string]struct{})
692+
varsMap := make(map[string]string)
687693
for _, gq := range gqls {
688694
if gq.Func != nil {
689695
predsMap[gq.Func.Attr] = struct{}{}
690696
}
691-
if len(gq.Attr) > 0 && gq.Attr != "uid" && gq.Attr != "expand" {
697+
if len(gq.Var) > 0 {
698+
varsMap[gq.Var] = gq.Attr
699+
}
700+
if len(gq.Attr) > 0 && gq.Attr != "uid" && gq.Attr != "expand" && gq.Attr != "val" {
692701
predsMap[gq.Attr] = struct{}{}
702+
693703
}
694704
for _, ord := range gq.Order {
695705
predsMap[ord.Attr] = struct{}{}
@@ -700,15 +710,23 @@ func parsePredsFromQuery(gqls []*gql.GraphQuery) []string {
700710
for _, pred := range parsePredsFromFilter(gq.Filter) {
701711
predsMap[pred] = struct{}{}
702712
}
703-
for _, childPred := range parsePredsFromQuery(gq.Children) {
713+
childPredandVars := parsePredsFromQuery(gq.Children)
714+
for _, childPred := range childPredandVars.preds {
704715
predsMap[childPred] = struct{}{}
705716
}
717+
for childVar := range childPredandVars.vars {
718+
varsMap[childVar] = childPredandVars.vars[childVar]
719+
}
706720
}
707721
preds := make([]string, 0, len(predsMap))
708722
for pred := range predsMap {
709-
preds = append(preds, pred)
723+
if _, found := varsMap[pred]; !found {
724+
preds = append(preds, pred)
725+
}
710726
}
711-
return preds
727+
728+
pv := predsAndvars{preds: preds, vars: varsMap}
729+
return pv
712730
}
713731

714732
func parsePredsFromFilter(f *gql.FilterTree) []string {
@@ -755,7 +773,16 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er
755773

756774
var userId string
757775
var groupIds []string
758-
preds := parsePredsFromQuery(parsedReq.Query)
776+
predsAndvars := parsePredsFromQuery(parsedReq.Query)
777+
preds := predsAndvars.preds
778+
varsToPredMap := predsAndvars.vars
779+
780+
// Need this to efficiently identify blocked variables from the
781+
// list of blocked predicates
782+
predToVarsMap := make(map[string]string)
783+
for k, v := range varsToPredMap {
784+
predToVarsMap[v] = k
785+
}
759786

760787
doAuthorizeQuery := func() (map[string]struct{}, []string, error) {
761788
userData, err := extractUserAndGroups(ctx)
@@ -806,7 +833,18 @@ func authorizeQuery(ctx context.Context, parsedReq *gql.Result, graphql bool) er
806833
// In query context ~predicate and predicate are considered different.
807834
delete(blockedPreds, "~dgraph.user.group")
808835
}
836+
837+
blockedVars := make(map[string]struct{})
838+
for predicate := range blockedPreds {
839+
if variable, found := predToVarsMap[predicate]; found {
840+
// Add variables to blockedPreds to delete from Query
841+
blockedPreds[variable] = struct{}{}
842+
// Collect blocked Variables to remove from QueryVars
843+
blockedVars[variable] = struct{}{}
844+
}
845+
}
809846
parsedReq.Query = removePredsFromQuery(parsedReq.Query, blockedPreds)
847+
parsedReq.QueryVars = removeVarsFromQueryVars(parsedReq.QueryVars, blockedVars)
810848
}
811849
for i := range parsedReq.Query {
812850
parsedReq.Query[i].AllowedPreds = allowedPreds
@@ -1057,6 +1095,7 @@ func removePredsFromQuery(gqs []*gql.GraphQuery,
10571095
blockedPreds map[string]struct{}) []*gql.GraphQuery {
10581096

10591097
filteredGQs := gqs[:0]
1098+
L:
10601099
for _, gq := range gqs {
10611100
if gq.Func != nil && len(gq.Func.Attr) > 0 {
10621101
if _, ok := blockedPreds[gq.Func.Attr]; ok {
@@ -1067,6 +1106,15 @@ func removePredsFromQuery(gqs []*gql.GraphQuery,
10671106
if _, ok := blockedPreds[gq.Attr]; ok {
10681107
continue
10691108
}
1109+
if gq.Attr == "val" {
1110+
// TODO (Anurag): If val supports multiple variables, this would
1111+
// need an upgrade
1112+
for _, variable := range gq.NeedsVar {
1113+
if _, ok := blockedPreds[variable.Name]; ok {
1114+
continue L
1115+
}
1116+
}
1117+
}
10701118
}
10711119

10721120
order := gq.Order[:0]
@@ -1087,6 +1135,30 @@ func removePredsFromQuery(gqs []*gql.GraphQuery,
10871135
return filteredGQs
10881136
}
10891137

1138+
func removeVarsFromQueryVars(gqs []*gql.Vars,
1139+
blockedVars map[string]struct{}) []*gql.Vars {
1140+
1141+
filteredGQs := gqs[:0]
1142+
for _, gq := range gqs {
1143+
var defines []string
1144+
var needs []string
1145+
for _, variable := range gq.Defines {
1146+
if _, ok := blockedVars[variable]; !ok {
1147+
defines = append(defines, variable)
1148+
}
1149+
}
1150+
for _, variable := range gq.Needs {
1151+
if _, ok := blockedVars[variable]; !ok {
1152+
needs = append(needs, variable)
1153+
}
1154+
}
1155+
gq.Defines = defines
1156+
gq.Needs = needs
1157+
filteredGQs = append(filteredGQs, gq)
1158+
}
1159+
return filteredGQs
1160+
}
1161+
10901162
func removeFilters(f *gql.FilterTree, blockedPreds map[string]struct{}) *gql.FilterTree {
10911163
if f == nil {
10921164
return nil

ee/acl/acl_test.go

+203
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,209 @@ func TestExpandQueryWithACLPermissions(t *testing.T) {
12541254
testutil.CompareJSON(t, `{"me":[{"name":"RandomGuy","age":23, "nickname":"RG"},{"name":"RandomGuy2","age":25, "nickname":"RG2"}]}`,
12551255
string(resp.GetJson()))
12561256

1257+
}
1258+
1259+
func TestValQueryWithACLPermissions(t *testing.T) {
1260+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
1261+
defer cancel()
1262+
dg, err := testutil.DgraphClientWithGroot(testutil.SockAddr)
1263+
require.NoError(t, err)
1264+
1265+
testutil.DropAll(t, dg)
1266+
1267+
op := api.Operation{Schema: `
1268+
name : string @index(exact) .
1269+
nickname : string @index(exact) .
1270+
age : int .
1271+
type TypeName {
1272+
name: string
1273+
nickname: string
1274+
age: int
1275+
}
1276+
`}
1277+
require.NoError(t, dg.Alter(ctx, &op))
1278+
1279+
resetUser(t)
1280+
1281+
accessJwt, _, err := testutil.HttpLogin(&testutil.LoginParams{
1282+
Endpoint: adminEndpoint,
1283+
UserID: "groot",
1284+
Passwd: "password",
1285+
})
1286+
require.NoError(t, err, "login failed")
1287+
1288+
createGroup(t, accessJwt, devGroup)
1289+
// createGroup(t, accessJwt, sreGroup)
1290+
1291+
// addRulesToGroup(t, accessJwt, sreGroup, []rule{{"age", Read.Code}, {"name", Write.Code}})
1292+
addToGroup(t, accessJwt, userid, devGroup)
1293+
1294+
txn := dg.NewTxn()
1295+
mutation := &api.Mutation{
1296+
SetNquads: []byte(`
1297+
_:a <name> "RandomGuy" .
1298+
_:a <age> "23" .
1299+
_:a <nickname> "RG" .
1300+
_:a <dgraph.type> "TypeName" .
1301+
_:b <name> "RandomGuy2" .
1302+
_:b <age> "25" .
1303+
_:b <nickname> "RG2" .
1304+
_:b <dgraph.type> "TypeName" .
1305+
`),
1306+
CommitNow: true,
1307+
}
1308+
_, err = txn.Mutate(ctx, mutation)
1309+
require.NoError(t, err)
1310+
1311+
query := `{q1(func: has(name)){
1312+
v as name
1313+
a as age
1314+
}
1315+
q2(func: eq(val(v), "RandomGuy")) {
1316+
val(v)
1317+
val(a)
1318+
}}`
1319+
1320+
// Test that groot has access to all the predicates
1321+
resp, err := dg.NewReadOnlyTxn().Query(ctx, query)
1322+
require.NoError(t, err, "Error while querying data")
1323+
testutil.CompareJSON(t, `{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],"q2":[{"val(v)":"RandomGuy","val(a)":23}]}`,
1324+
string(resp.GetJson()))
1325+
1326+
// All test cases
1327+
tests := []struct {
1328+
input string
1329+
descriptionNoPerm string
1330+
outputNoPerm string
1331+
descriptionNamePerm string
1332+
outputNamePerm string
1333+
descriptionNameAgePerm string
1334+
outputNameAgePerm string
1335+
}{
1336+
{
1337+
`
1338+
{
1339+
q1(func: has(name)) {
1340+
v as name
1341+
a as age
1342+
}
1343+
q2(func: eq(val(v), "RandomGuy")) {
1344+
val(v)
1345+
val(a)
1346+
}
1347+
}
1348+
`,
1349+
"alice doesn't have access to name or age",
1350+
`{}`,
1351+
1352+
`alice has access to name`,
1353+
`{"q1":[{"name":"RandomGuy"},{"name":"RandomGuy2"}],"q2":[{"val(v)":"RandomGuy"}]}`,
1354+
1355+
"alice has access to name and age",
1356+
`{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],"q2":[{"val(v)":"RandomGuy","val(a)":23}]}`,
1357+
},
1358+
{
1359+
`{
1360+
q1(func: has(name) ) {
1361+
a as age
1362+
}
1363+
q2(func: has(name) ) {
1364+
val(a)
1365+
}
1366+
}`,
1367+
"alice doesn't have access to name or age",
1368+
`{}`,
1369+
1370+
`alice has access to name`,
1371+
`{"q1":[],"q2":[]}`,
1372+
1373+
"alice has access to name and age",
1374+
`{"q1":[{"age":23},{"age":25}],"q2":[{"val(a)":23},{"val(a)":25}]}`,
1375+
},
1376+
{
1377+
`{
1378+
f as q1(func: has(name) ) {
1379+
n as name
1380+
a as age
1381+
}
1382+
q2(func: uid(f), orderdesc: val(a) ) {
1383+
name
1384+
val(n)
1385+
val(a)
1386+
}
1387+
}`,
1388+
"alice doesn't have access to name or age",
1389+
`{"q2":[]}`,
1390+
1391+
`alice has access to name`,
1392+
`{"q1":[{"name":"RandomGuy"},{"name":"RandomGuy2"}],
1393+
"q2":[{"name":"RandomGuy","val(n)":"RandomGuy"},{"name":"RandomGuy2","val(n)":"RandomGuy2"}]}`,
1394+
1395+
"alice has access to name and age",
1396+
`{"q1":[{"name":"RandomGuy","age":23},{"name":"RandomGuy2","age":25}],
1397+
"q2":[{"name":"RandomGuy2","val(n)":"RandomGuy2","val(a)":25},{"name":"RandomGuy","val(n)":"RandomGuy","val(a)":23}]}`,
1398+
},
1399+
}
1400+
1401+
userClient, err := testutil.DgraphClient(testutil.SockAddr)
1402+
require.NoError(t, err)
1403+
time.Sleep(6 * time.Second)
1404+
1405+
err = userClient.Login(ctx, userid, userpassword)
1406+
require.NoError(t, err)
1407+
1408+
// Query via user when user has no permissions
1409+
for _, tc := range tests {
1410+
desc := tc.descriptionNoPerm
1411+
t.Run(desc, func(t *testing.T) {
1412+
resp, err := userClient.NewTxn().Query(ctx, tc.input)
1413+
require.NoError(t, err)
1414+
testutil.CompareJSON(t, tc.outputNoPerm, string(resp.Json))
1415+
})
1416+
}
1417+
1418+
// Login to groot to modify accesses (1)
1419+
accessJwt, _, err = testutil.HttpLogin(&testutil.LoginParams{
1420+
Endpoint: adminEndpoint,
1421+
UserID: "groot",
1422+
Passwd: "password",
1423+
})
1424+
require.NoError(t, err, "login failed")
1425+
1426+
// Give read access of <name> to dev
1427+
addRulesToGroup(t, accessJwt, devGroup, []rule{{"name", Read.Code}})
1428+
time.Sleep(6 * time.Second)
1429+
1430+
for _, tc := range tests {
1431+
desc := tc.descriptionNamePerm
1432+
t.Run(desc, func(t *testing.T) {
1433+
resp, err := userClient.NewTxn().Query(ctx, tc.input)
1434+
require.NoError(t, err)
1435+
testutil.CompareJSON(t, tc.outputNamePerm, string(resp.Json))
1436+
})
1437+
}
1438+
1439+
// Login to groot to modify accesses (1)
1440+
accessJwt, _, err = testutil.HttpLogin(&testutil.LoginParams{
1441+
Endpoint: adminEndpoint,
1442+
UserID: "groot",
1443+
Passwd: "password",
1444+
})
1445+
require.NoError(t, err, "login failed")
1446+
1447+
// Give read access of <name> and <age> to dev
1448+
addRulesToGroup(t, accessJwt, devGroup, []rule{{"name", Read.Code}, {"age", Read.Code}})
1449+
time.Sleep(6 * time.Second)
1450+
1451+
for _, tc := range tests {
1452+
desc := tc.descriptionNameAgePerm
1453+
t.Run(desc, func(t *testing.T) {
1454+
resp, err := userClient.NewTxn().Query(ctx, tc.input)
1455+
require.NoError(t, err)
1456+
testutil.CompareJSON(t, tc.outputNameAgePerm, string(resp.Json))
1457+
})
1458+
}
1459+
12571460
}
12581461
func TestNewACLPredicates(t *testing.T) {
12591462
ctx, _ := context.WithTimeout(context.Background(), 100*time.Second)

0 commit comments

Comments
 (0)