diff --git a/ee/enc/util_ee_test.go b/ee/enc/util_ee_test.go index 9d98462dfdc..92ed0434d53 100644 --- a/ee/enc/util_ee_test.go +++ b/ee/enc/util_ee_test.go @@ -41,8 +41,9 @@ func resetConfig(config *viper.Viper) { config.Set(vaultAddr, "http://localhost:8200") config.Set(vaultRoleIDFile, "") config.Set(vaultSecretIDFile, "") - config.Set(vaultPath, "dgraph") + config.Set(vaultPath, "secret/data/dgraph") config.Set(vaultField, "enc_key") + config.Set(vaultFormat, "base64") } // TODO: The function below allows instantiating a real Vault server. But results in go.mod issues. @@ -117,6 +118,15 @@ func TestNewKeyReader(t *testing.T) { require.Nil(t, k) require.Error(t, err) + // Bad vault_format. Must be raw or base64. + resetConfig(config) + config.Set(vaultRoleIDFile, "./test-fixtures/dummy_role_id_file") + config.Set(vaultSecretIDFile, "./test-fixtures/dummy_secret_id_file") + config.Set(vaultFormat, "foo") // error. + kr, err = newKeyReader(config) + require.Error(t, err) + require.Nil(t, kr) + // RoleID and SecretID given but RoleID file and SecretID file exists and is valid. resetConfig(config) //nl, _ := startVaultServer(t, "dgraph", "enc_key", "1234567890123456") diff --git a/ee/enc/vault_ee.go b/ee/enc/vault_ee.go index 35d000c2106..63d282e176f 100644 --- a/ee/enc/vault_ee.go +++ b/ee/enc/vault_ee.go @@ -19,9 +19,11 @@ package enc import ( + "encoding/base64" "io/ioutil" "github.com/dgraph-io/dgraph/x" + "github.com/golang/glog" "github.com/hashicorp/vault/api" "github.com/pkg/errors" "github.com/spf13/pflag" @@ -35,6 +37,7 @@ const ( vaultSecretIDFile = "vault_secretid_file" vaultPath = "vault_path" vaultField = "vault_field" + vaultFormat = "vault_format" ) // RegisterVaultFlags registers the required flags to integrate with Vault. @@ -46,10 +49,12 @@ func registerVaultFlags(flag *pflag.FlagSet) { "File containing Vault role-id used for approle auth.") flag.String(vaultSecretIDFile, "", "File containing Vault secret-id used for approle auth.") - flag.String(vaultPath, "dgraph", - "Vault kv store path.") + flag.String(vaultPath, "secret/data/dgraph", + "Vault kv store path. e.g. secret/data/dgraph for kv-v2, kv/dgraph for kv-v1.") flag.String(vaultField, "enc_key", - "Vault kv store field whose value is the encryption key.") + "Vault kv store field whose value is the Base64 encoded encryption key.") + flag.String(vaultFormat, "base64", + "Vault field format. raw or base64") } // vaultKeyReader implements the KeyReader interface. It reads the key from vault server. @@ -59,6 +64,7 @@ type vaultKeyReader struct { secretID string path string field string + format string } func newVaultKeyReader(cfg *viper.Viper) (*vaultKeyReader, error) { @@ -68,10 +74,15 @@ func newVaultKeyReader(cfg *viper.Viper) (*vaultKeyReader, error) { secretID: cfg.GetString(vaultSecretIDFile), path: cfg.GetString(vaultPath), field: cfg.GetString(vaultField), + format: cfg.GetString(vaultFormat), } - if v.addr == "" || v.path == "" || v.field == "" { - return nil, errors.Errorf("%v, %v or %v is missing", vaultAddr, vaultPath, vaultField) + if v.addr == "" || v.path == "" || v.field == "" || v.format == "" { + return nil, errors.Errorf("%v, %v, %v or %v is missing", + vaultAddr, vaultPath, vaultField, vaultFormat) + } + if v.format != "base64" && v.format != "raw" { + return nil, errors.Errorf("vault_format = %v; must be one of base64 or raw", v.format) } if v.roleID != "" && v.secretID != "" { @@ -120,25 +131,32 @@ func (vkr *vaultKeyReader) readKey() (x.SensitiveByteSlice, error) { } client.SetToken(resp.Auth.ClientToken) - // Read from KV store - secret, err := client.Logical().Read("secret/data/" + vkr.path) + // Read from KV store. The given path must be v1 or v2 format. We use it as is. + secret, err := client.Logical().Read(vkr.path) if err != nil || secret == nil { return nil, errors.Errorf("error or nil secret on reading key at %v: "+ "err %v", vkr.path, err) } // Parse key from response + var m map[string]interface{} m, ok := secret.Data["data"].(map[string]interface{}) if !ok { - return nil, errors.Errorf("kv store read response from vault is bad") + glog.Infof("Unable to extract key from kv v2 response. Trying kv v1.") + m = secret.Data } kVal, ok := m[vkr.field] if !ok { return nil, errors.Errorf("secret key not found at %v", vkr.field) } kbyte := []byte(kVal.(string)) - - // Validate key length suitable for AES + if vkr.format == "base64" { + kbyte, err = base64.StdEncoding.DecodeString(kVal.(string)) + if err != nil { + return nil, errors.Errorf("Unable to decode the Base64 Encoded key: err %v", err) + } + } + // Validate key length suitable for AES. klen := len(kbyte) if klen != 16 && klen != 32 && klen != 64 { return nil, errors.Errorf("bad key length %v from vault", klen)