Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow non-strings to be used to set ttl field in generic. #2699

Merged
merged 2 commits into from
May 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions vault/logical_passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"fmt"
"strings"

"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/helper/jsonutil"
"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/logical"
"github.com/hashicorp/vault/logical/framework"
)
Expand Down Expand Up @@ -126,14 +126,13 @@ func (b *PassthroughBackend) handleRead(
}

// Check if there is a ttl key
var ttl string
ttl, _ = rawData["ttl"].(string)
if len(ttl) == 0 {
ttl, _ = rawData["lease"].(string)
}
ttlDuration := b.System().DefaultLeaseTTL()
if len(ttl) != 0 {
dur, err := parseutil.ParseDurationSecond(ttl)
ttlRaw, ok := rawData["ttl"]
if !ok {
ttlRaw, ok = rawData["lease"]
}
if ok {
dur, err := parseutil.ParseDurationSecond(ttlRaw)
if err == nil {
ttlDuration = dur
}
Expand Down
47 changes: 38 additions & 9 deletions vault/logical_passthrough_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package vault

import (
"encoding/json"
"reflect"
"testing"
"time"

"github.com/hashicorp/vault/helper/parseutil"
"github.com/hashicorp/vault/logical"
)

Expand Down Expand Up @@ -49,10 +51,19 @@ func TestPassthroughBackend_Write(t *testing.T) {
}

func TestPassthroughBackend_Read(t *testing.T) {
test := func(b logical.Backend, ttlType string, leased bool) {
test := func(b logical.Backend, ttlType string, ttl interface{}, leased bool) {
req := logical.TestRequest(t, logical.UpdateOperation, "foo")
req.Data["raw"] = "test"
req.Data[ttlType] = "1h"
var reqTTL interface{}
switch ttl.(type) {
case int64:
reqTTL = ttl.(int64)
case string:
reqTTL = ttl.(string)
default:
t.Fatal("unknown ttl type")
}
req.Data[ttlType] = reqTTL
storage := req.Storage

if _, err := b.HandleRequest(req); err != nil {
Expand All @@ -67,16 +78,34 @@ func TestPassthroughBackend_Read(t *testing.T) {
t.Fatalf("err: %v", err)
}

expectedTTL, err := parseutil.ParseDurationSecond(ttl)
if err != nil {
t.Fatal(err)
}

// What comes back if an int is passed in is a json.Number which is
// actually aliased as a string so to make the deep equal happy if it's
// actually a number we set it to an int64
var respTTL interface{} = resp.Data[ttlType]
_, ok := respTTL.(json.Number)
if ok {
respTTL, err = respTTL.(json.Number).Int64()
if err != nil {
t.Fatal(err)
}
resp.Data[ttlType] = respTTL
}

expected := &logical.Response{
Secret: &logical.Secret{
LeaseOptions: logical.LeaseOptions{
Renewable: true,
TTL: time.Hour,
TTL: expectedTTL,
},
},
Data: map[string]interface{}{
"raw": "test",
ttlType: "1h",
ttlType: reqTTL,
},
}

Expand All @@ -86,15 +115,15 @@ func TestPassthroughBackend_Read(t *testing.T) {
resp.Secret.InternalData = nil
resp.Secret.LeaseID = ""
if !reflect.DeepEqual(resp, expected) {
t.Fatalf("bad response.\n\nexpected: %#v\n\nGot: %#v", expected, resp)
t.Fatalf("bad response.\n\nexpected:\n%#v\n\nGot:\n%#v", expected, resp)
}
}
b := testPassthroughLeasedBackend()
test(b, "lease", true)
test(b, "ttl", true)
test(b, "lease", "1h", true)
test(b, "ttl", "5", true)
b = testPassthroughBackend()
test(b, "lease", false)
test(b, "ttl", false)
test(b, "lease", int64(10), false)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we test for passing 10 directly, as opposed to int64(10)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this ensures that the interface value is the right type. Otherwise this defaults to an int.

test(b, "ttl", "40s", false)
}

func TestPassthroughBackend_Delete(t *testing.T) {
Expand Down