Skip to content

Commit

Permalink
fix(vault): Hide ACL flags when not required (#7701)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajeetdsouza authored Apr 8, 2021
1 parent a77bbe8 commit 8046aff
Show file tree
Hide file tree
Showing 17 changed files with 191 additions and 207 deletions.
15 changes: 2 additions & 13 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ they form a Raft group and provide synchronous replication.
// --tls SuperFlag
x.RegisterServerTLSFlags(flag)
// --encryption and --vault Superflag
enc.RegisterFlags(flag)
ee.RegisterAclAndEncFlags(flag)

flag.StringP("postings", "p", "p", "Directory to store posting lists.")
flag.String("tmp", "t", "Directory to store temporary buffers.")
Expand Down Expand Up @@ -181,17 +181,6 @@ they form a Raft group and provide synchronous replication.
`internal").`).
String())

flag.String("acl", worker.AclDefaults, z.NewSuperFlagHelp(worker.AclDefaults).
Head("[Enterprise Feature] ACL options").
Flag("secret-file",
"The file that stores the HMAC secret, which is used for signing the JWT and "+
"should have at least 32 ASCII characters. Required to enable ACLs.").
Flag("access-ttl",
"The TTL for the access JWT.").
Flag("refresh-ttl",
"The TTL for the refresh JWT.").
String())

flag.String("limit", worker.LimitDefaults, z.NewSuperFlagHelp(worker.LimitDefaults).
Head("Limit options").
Flag("query-edge",
Expand Down Expand Up @@ -660,7 +649,7 @@ func run() {
if aclKey != nil {
opts.HmacSecret = aclKey

acl := z.NewSuperFlag(Alpha.Conf.GetString("acl")).MergeAndCheckDefault(worker.AclDefaults)
acl := z.NewSuperFlag(Alpha.Conf.GetString("acl")).MergeAndCheckDefault(ee.AclDefaults)
opts.AccessJwtTtl = acl.GetDuration("access-ttl")
opts.RefreshJwtTtl = acl.GetDuration("refresh-ttl")

Expand Down
3 changes: 1 addition & 2 deletions dgraph/cmd/bulk/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"github.com/dgraph-io/dgraph/worker"
"github.com/dgraph-io/ristretto/z"

"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/tok"
"github.com/dgraph-io/dgraph/x"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -133,7 +132,7 @@ func init() {

x.RegisterClientTLSFlags(flag)
// Encryption and Vault options
enc.RegisterFlags(flag)
ee.RegisterEncFlag(flag)
}

func run() {
Expand Down
3 changes: 1 addition & 2 deletions dgraph/cmd/debug/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import (

"github.com/dgraph-io/dgraph/codec"
"github.com/dgraph-io/dgraph/ee"
"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/posting"
"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/raftwal"
Expand Down Expand Up @@ -109,7 +108,7 @@ func init() {
flag.StringVarP(&opt.wsetSnapshot, "snap", "s", "",
"Set snapshot term,index,readts to this. Value must be comma-separated list containing"+
" the value for these vars in that order.")
enc.RegisterFlags(flag)
ee.RegisterEncFlag(flag)
}

func toInt(o *pb.Posting) int {
Expand Down
2 changes: 1 addition & 1 deletion dgraph/cmd/decrypt/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func init() {
flag := Decrypt.Cmd.Flags()
flag.StringP("file", "f", "", "Path to file to decrypt.")
flag.StringP("out", "o", "", "Path to the decrypted file.")
enc.RegisterFlags(flag)
ee.RegisterEncFlag(flag)
}
func run() {
opts := options{
Expand Down
4 changes: 2 additions & 2 deletions dgraph/cmd/live/load-uids/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/ee"
"github.com/dgraph-io/dgraph/testutil"
"github.com/dgraph-io/dgraph/x"
)
Expand Down Expand Up @@ -286,7 +286,7 @@ func TestLiveLoadExportedSchema(t *testing.T) {
"--schema", localExportPath + "/" + exportId + "/" + groupId + ".schema.gz",
"--files", localExportPath + "/" + exportId + "/" + groupId + ".rdf.gz",
"--encryption",
enc.BuildEncFlag(testDataDir + "/../../../../ee/enc/test-fixtures/enc-key"),
ee.BuildEncFlag(testDataDir + "/../../../../ee/enc/test-fixtures/enc-key"),
"--alpha", alphaService, "--zero", zeroService,
"--creds", "user=groot;password=password;"},
}
Expand Down
2 changes: 1 addition & 1 deletion dgraph/cmd/live/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func init() {

flag := Live.Cmd.Flags()
// --vault SuperFlag and encryption flags
enc.RegisterFlags(flag)
ee.RegisterEncFlag(flag)
// --tls SuperFlag
x.RegisterClientTLSFlags(flag)

Expand Down
5 changes: 2 additions & 3 deletions ee/backup/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import (
"google.golang.org/grpc/credentials"

"github.com/dgraph-io/dgraph/ee"
"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/upgrade"
"github.com/dgraph-io/dgraph/worker"
Expand Down Expand Up @@ -152,7 +151,7 @@ $ dgraph restore -p . -l /var/backups/dgraph -z localhost:5080
"update the timestamp and max uid when you start the cluster. The correct values are "+
"printed near the end of this command's output.")
x.RegisterClientTLSFlags(flag)
enc.RegisterFlags(flag)
ee.RegisterEncFlag(flag)
_ = Restore.Cmd.MarkFlagRequired("postings")
_ = Restore.Cmd.MarkFlagRequired("location")
}
Expand Down Expand Up @@ -342,7 +341,7 @@ func initExportBackup() {
`If true, retrieve the CORS from DB and append at the end of GraphQL schema.
It also deletes the deprecated types and predicates.
Use this option when exporting a backup of 20.11 for loading onto 21.03.`)
enc.RegisterFlags(flag)
ee.RegisterEncFlag(flag)
}

func runExportBackup() error {
Expand Down
43 changes: 0 additions & 43 deletions ee/enc/flags.go

This file was deleted.

139 changes: 139 additions & 0 deletions ee/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2021 Dgraph Labs, Inc. and Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package ee

import (
"fmt"
"strings"

"github.com/dgraph-io/ristretto/z"
"github.com/spf13/pflag"
)

const (
flagAcl = "acl"
flagAclAccessTtl = "access-ttl"
flagAclRefreshTtl = "refresh-ttl"
flagAclSecretFile = "secret-file"

flagEnc = "encryption"
flagEncKeyFile = "key-file"

flagVault = "vault"
flagVaultAddr = "addr"
flagVaultRoleIdFile = "role-id-file"
flagVaultSecretIdFile = "secret-id-file"
flagVaultPath = "path"
flagVaultAclField = "acl-field"
flagVaultAclFormat = "acl-format"
flagVaultEncField = "enc-field"
flagVaultEncFormat = "enc-format"
)

func RegisterAclAndEncFlags(flag *pflag.FlagSet) {
registerAclFlag(flag)
registerEncFlag(flag)
registerVaultFlag(flag, true, true)
}

func RegisterEncFlag(flag *pflag.FlagSet) {
registerEncFlag(flag)
registerVaultFlag(flag, false, true)
}

var (
AclDefaults = fmt.Sprintf("%s=%s; %s=%s; %s=%s",
flagAclAccessTtl, "6h",
flagAclRefreshTtl, "30d",
flagAclSecretFile, "")
encDefaults = fmt.Sprintf("%s=%s", flagEncKeyFile, "")
)

func vaultDefaults(aclEnabled, encEnabled bool) string {
var configBuilder strings.Builder
fmt.Fprintf(&configBuilder, "%s=%s; %s=%s; %s=%s; %s=%s",
flagVaultAddr, "http://localhost:8200",
flagVaultRoleIdFile, "",
flagVaultSecretIdFile, "",
flagVaultPath, "secret/data/dgraph")
if aclEnabled {
fmt.Fprintf(&configBuilder, "; %s=%s; %s=%s",
flagVaultAclField, "",
flagVaultAclFormat, "base64")
}
if encEnabled {
fmt.Fprintf(&configBuilder, "; %s=%s; %s=%s",
flagVaultEncField, "",
flagVaultEncFormat, "base64")
}
return configBuilder.String()
}

func registerVaultFlag(flag *pflag.FlagSet, aclEnabled, encEnabled bool) {
// Generate default configuration.
config := vaultDefaults(aclEnabled, encEnabled)

// Generate help text.
helpBuilder := z.NewSuperFlagHelp(config).
Head("Vault options").
Flag(flagVaultAddr, "Vault server address (format: http://ip:port).").
Flag(flagVaultRoleIdFile, "Vault RoleID file, used for AppRole authentication.").
Flag(flagVaultSecretIdFile, "Vault SecretID file, used for AppRole authentication.").
Flag(flagVaultPath, "Vault KV store path (e.g. 'secret/data/dgraph' for KV V2, "+
"'kv/dgraph' for KV V1).")
if aclEnabled {
helpBuilder = helpBuilder.
Flag(flagVaultAclField, "Vault field containing ACL key.").
Flag(flagVaultAclFormat, "ACL key format, can be 'raw' or 'base64'.")
}
if encEnabled {
helpBuilder = helpBuilder.
Flag(flagVaultEncField, "Vault field containing encryption key.").
Flag(flagVaultEncFormat, "Encryption key format, can be 'raw' or 'base64'.")
}
helpText := helpBuilder.String()

// Register flag.
flag.String(flagVault, config, helpText)
}

func registerAclFlag(flag *pflag.FlagSet) {
helpText := z.NewSuperFlagHelp(AclDefaults).
Head("[Enterprise Feature] ACL options").
Flag("secret-file",
"The file that stores the HMAC secret, which is used for signing the JWT and "+
"should have at least 32 ASCII characters. Required to enable ACLs.").
Flag("access-ttl",
"The TTL for the access JWT.").
Flag("refresh-ttl",
"The TTL for the refresh JWT.").
String()
flag.String(flagAcl, AclDefaults, helpText)
}

func registerEncFlag(flag *pflag.FlagSet) {
helpText := z.NewSuperFlagHelp(encDefaults).
Head("[Enterprise Feature] Encryption At Rest options").
Flag("key-file", "The file that stores the symmetric key of length 16, 24, or 32 bytes."+
"The key size determines the chosen AES cipher (AES-128, AES-192, and AES-256 respectively).").
String()
flag.String(flagEnc, encDefaults, helpText)
}

func BuildEncFlag(filename string) string {
return fmt.Sprintf("key-file=%s;", filename)
}
6 changes: 2 additions & 4 deletions ee/utils_ee.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ package ee
import (
"io/ioutil"

"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/ee/vault"
"github.com/dgraph-io/dgraph/x"
"github.com/dgraph-io/ristretto/z"
"github.com/golang/glog"
Expand All @@ -28,7 +26,7 @@ import (
// this function exits with an error.
func GetKeys(config *viper.Viper) (x.SensitiveByteSlice, x.SensitiveByteSlice) {
aclSuperFlag := z.NewSuperFlag(config.GetString("acl"))
aclKey, encKey := vault.GetKeys(config)
aclKey, encKey := vaultGetKeys(config)
var err error

aclKeyFile := aclSuperFlag.GetPath("secret-file")
Expand All @@ -44,7 +42,7 @@ func GetKeys(config *viper.Viper) (x.SensitiveByteSlice, x.SensitiveByteSlice) {
glog.Exitf("ACL secret key must have length of at least 32 bytes, got %d bytes instead", l)
}

encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(enc.EncryptionDefaults)
encSuperFlag := z.NewSuperFlag(config.GetString("encryption")).MergeAndCheckDefault(encDefaults)
encKeyFile := encSuperFlag.GetPath("key-file")
if encKeyFile != "" {
if encKey != nil {
Expand Down
Loading

0 comments on commit 8046aff

Please sign in to comment.