From dc977a85d22d1bb5bfd97f593722cabf9e428c8b Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Thu, 10 Jun 2021 12:50:02 -0500 Subject: [PATCH 01/11] Go recorded test framework --- sdk/internal/go.mod | 9 +- sdk/internal/go.sum | 37 +- sdk/internal/testframework/recording.go | 414 ++++++++++++++++++ .../testframework/recording_sanitizer.go | 83 ++++ .../testframework/recording_sanitizer_test.go | 157 +++++++ sdk/internal/testframework/recording_test.go | 356 +++++++++++++++ sdk/internal/testframework/request_matcher.go | 111 +++++ .../testframework/request_matcher_test.go | 193 ++++++++ sdk/internal/testframework/testcontext.go | 50 +++ sdk/internal/uuid/uuid.go | 14 + 10 files changed, 1412 insertions(+), 12 deletions(-) create mode 100644 sdk/internal/testframework/recording.go create mode 100644 sdk/internal/testframework/recording_sanitizer.go create mode 100644 sdk/internal/testframework/recording_sanitizer_test.go create mode 100644 sdk/internal/testframework/recording_test.go create mode 100644 sdk/internal/testframework/request_matcher.go create mode 100644 sdk/internal/testframework/request_matcher_test.go create mode 100644 sdk/internal/testframework/testcontext.go diff --git a/sdk/internal/go.mod b/sdk/internal/go.mod index 8c500ce1c435..5a5a2fb7eb04 100644 --- a/sdk/internal/go.mod +++ b/sdk/internal/go.mod @@ -2,4 +2,11 @@ module github.com/Azure/azure-sdk-for-go/sdk/internal go 1.14 -require golang.org/x/net v0.0.0-20201010224723-4f7140c49acb +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dnaeon/go-vcr v1.2.0 + github.com/stretchr/testify v1.7.0 + golang.org/x/net v0.0.0-20210610132358-84b48f89b13b + gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect +) diff --git a/sdk/internal/go.sum b/sdk/internal/go.sum index c59642cdfa12..7064208964e3 100644 --- a/sdk/internal/go.sum +++ b/sdk/internal/go.sum @@ -1,12 +1,27 @@ -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb h1:mUVeFHoDKis5nxCAzoAi7E8Ghb86EXh/RK6wtvJIqRY= -golang.org/x/net v0.0.0-20201010224723-4f7140c49acb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= +github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/net v0.0.0-20210610132358-84b48f89b13b h1:k+E048sYJHyVnsr1GDrRZWQ32D2C7lWs9JRc0bel53A= +golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/testframework/recording.go new file mode 100644 index 000000000000..2b008d55f3f8 --- /dev/null +++ b/sdk/internal/testframework/recording.go @@ -0,0 +1,414 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/dnaeon/go-vcr/cassette" + "github.com/dnaeon/go-vcr/recorder" + "gopkg.in/yaml.v2" +) + +type Recording struct { + SessionName string + RecordingFile string + VariablesFile string + Mode RecordMode + variables map[string]*string `yaml:"variables"` + previousSessionVariables map[string]*string `yaml:"variables"` + recorder *recorder.Recorder + src rand.Source + now *time.Time + Sanitizer *RecordingSanitizer + c TestContext +} + +const ( + alphanumericBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" + alphanumericLowercaseBytes = "abcdefghijklmnopqrstuvwxyz1234567890" + randomSeedVariableName = "randomSeed" + nowVariableName = "now" + ModeEnvironmentVariableName = "AZURE_TEST_MODE" +) + +// Inspired by https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go +const ( + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1< 0 { + // Merge values from previousVariables that are not in variables to variables + for k, v := range r.previousSessionVariables { + if _, ok := r.variables[k]; ok { + // skip variables that were new in the current session + continue + } + r.variables[k] = v + } + + // Marshal to YAML and save variables + data, err := yaml.Marshal(r.variables) + if err != nil { + return err + } + + f, err := r.createVariablesFileIfNotExists() + if err != nil { + return err + } + + defer f.Close() + + // http://www.yaml.org/spec/1.2/spec.html#id2760395 + _, err = f.Write([]byte("---\n")) + if err != nil { + return err + } + + _, err = f.Write(data) + if err != nil { + return err + } + } + return nil +} + +func (r *Recording) Now() time.Time { + r.initNow() + + return *r.now +} + +func (r *Recording) UUID() uuid.UUID { + r.initRandomSource() + + return uuid.FromSource(r.src) +} + +// GenerateAlphaNumericID will generate a recorded random alpha numeric id +// if the recording has a randomSeed already set, the value will be generated from that seed, else a new random seed will be used +func (r *Recording) GenerateAlphaNumericID(prefix string, length int, lowercaseOnly bool) (string, error) { + + if length <= len(prefix) { + return "", errors.New("length must be greater than prefix") + } + + r.initRandomSource() + + sb := strings.Builder{} + sb.Grow(length) + sb.WriteString(prefix) + i := length - len(prefix) - 1 + // A src.Int63() generates 63 random bits, enough for letterIdxMax characters! + for cache, remain := r.src.Int63(), letterIdxMax; i >= 0; { + if remain == 0 { + cache, remain = r.src.Int63(), letterIdxMax + } + if lowercaseOnly { + if idx := int(cache & letterIdxMask); idx < len(alphanumericLowercaseBytes) { + sb.WriteByte(alphanumericLowercaseBytes[idx]) + i-- + } + } else { + if idx := int(cache & letterIdxMask); idx < len(alphanumericBytes) { + sb.WriteByte(alphanumericBytes[idx]) + i-- + } + } + cache >>= letterIdxBits + remain-- + } + str := sb.String() + return str, nil +} + +// getRequiredEnv gets an environment variable by name and returns an error if it is not found +func getRequiredEnv(name string) (*string, error) { + env, ok := os.LookupEnv(name) + if ok { + return &env, nil + } else { + return nil, errors.New(envNotExistsError(name)) + } +} + +// getOptionalEnv gets an environment variable by name and returns the defaultValue if not found +func getOptionalEnv(name string, defaultValue string) *string { + env, ok := os.LookupEnv(name) + if ok { + return &env + } else { + return &defaultValue + } +} + +func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool { + isMatch := compareMethods(req, rec, r.c) && + compareURLs(req, rec, r.c) && + compareHeaders(req, rec, r.c) && + compareBodies(req, rec, r.c) + + return isMatch +} + +func missingRequestError(req *http.Request) string { + reqUrl := req.URL.String() + return fmt.Sprintf("\nNo matching recorded request found.\nRequest: [%s] %s\n", req.Method, reqUrl) +} + +func envNotExistsError(varName string) string { + return "Required environment variable not set: " + varName +} + +// applyVariableOptions applies the VariableType transform to the value +// If variableType is not provided or Default, return result +// If variableType is Secret_String, return SanitizedValue +// If variableType isSecret_Base64String return SanitizedBase64Value +func applyVariableOptions(val *string, variableType VariableType) *string { + var ret string + + switch variableType { + case Secret_String: + ret = SanitizedValue + return &ret + case Secret_Base64String: + ret = SanitizedBase64Value + return &ret + default: + return val + } +} + +// initRandomSource initializes the Source to be used for random value creation in this Recording +func (r *Recording) initRandomSource() { + // if we already have a Source generated, return immediately + if r.src != nil { + return + } + + var seed int64 + var err error + + // check to see if we already have a random seed stored, use that if so + seedString, ok := r.previousSessionVariables[randomSeedVariableName] + if ok { + seed, err = strconv.ParseInt(*seedString, 10, 64) + } + + // We did not have a random seed already stored; create a new one + if !ok || err != nil || r.Mode == Live { + seed = time.Now().Unix() + val := strconv.FormatInt(seed, 10) + r.variables[randomSeedVariableName] = &val + } + + // create a Source with the seed + r.src = rand.NewSource(seed) +} + +// initNow initializes the Source to be used for random value creation in this Recording +func (r *Recording) initNow() { + // if we already have a now generated, return immediately + if r.now != nil { + return + } + + var err error + var nowStr *string + var newNow time.Time + + // check to see if we already have a random seed stored, use that if so + nowStr, ok := r.previousSessionVariables[nowVariableName] + if ok { + newNow, err = time.Parse(time.RFC3339Nano, *nowStr) + } + + // We did not have a random seed already stored; create a new one + if !ok || err != nil || r.Mode == Live { + newNow = time.Now() + nowStr = new(string) + *nowStr = newNow.Format(time.RFC3339Nano) + r.variables[nowVariableName] = nowStr + } + + // save the now value. + r.now = &newNow +} + +// getFilePaths returns (recordingFilePath, variablesFilePath) +func getFilePaths(name string) (string, string) { + recPath := "recordings/" + name + varPath := fmt.Sprintf("%s-variables.yaml", recPath) + return recPath, varPath +} + +// createVariablesFileIfNotExists calls os.Create on the VariablesFile and creates it if it or the path does not exist +// Callers must call Close on the result +func (r *Recording) createVariablesFileIfNotExists() (*os.File, error) { + f, err := os.Create(r.VariablesFile) + if err != nil { + if !os.IsNotExist(err) { + return nil, err + } + // Create directory for the variables if missing + variablesDir := filepath.Dir(r.VariablesFile) + if _, err := os.Stat(variablesDir); os.IsNotExist(err) { + if err = os.MkdirAll(variablesDir, 0755); err != nil { + return nil, err + } + } + + f, err = os.Create(r.VariablesFile) + if err != nil { + return nil, err + } + } + + return f, nil +} + +func (r *Recording) unmarshalVariablesFile(out interface{}) error { + data, err := ioutil.ReadFile(r.VariablesFile) + if err != nil { + // If the file or dir do not exist, this is not an error to report + if os.IsNotExist(err) { + r.c.Log(fmt.Sprintf("Did not find recording for test '%s'", r.RecordingFile)) + return nil + } else { + return err + } + } else { + err = yaml.Unmarshal(data, out) + } + return nil +} + +func (r *Recording) initVariables() error { + return r.unmarshalVariablesFile(r.previousSessionVariables) +} + +var modeMap = map[RecordMode]recorder.Mode{ + Record: recorder.ModeRecording, + Live: recorder.ModeDisabled, + Playback: recorder.ModeReplaying, +} diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/recording_sanitizer.go new file mode 100644 index 000000000000..873844b7703b --- /dev/null +++ b/sdk/internal/testframework/recording_sanitizer.go @@ -0,0 +1,83 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "net/http" + + "github.com/dnaeon/go-vcr/cassette" + "github.com/dnaeon/go-vcr/recorder" +) + +type RecordingSanitizer struct { + recorder *recorder.Recorder + headersToSanitize map[string]*string + urlSanitizer StringSanitizer + bodySanitizer StringSanitizer +} + +type StringSanitizer func(*string) + +const SanitizedValue string = "sanitized" +const SanitizedBase64Value string = "Kg==" + +var sanitizedValueSlice = []string{SanitizedValue} + +func DefaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { + // The default sanitizer sanitizes the Authorization header + s := &RecordingSanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} + recorder.AddSaveFilter(s.applySaveFilter) + + return s +} + +// AddSanitizedHeaders adds the supplied header names to the list of headers to be sanitized on request and response recordings. +func (s *RecordingSanitizer) AddSanitizedHeaders(headers ...string) { + for _, headerName := range headers { + s.headersToSanitize[headerName] = nil + } +} + +// AddBodysanitizer configures the supplied StringSanitizer to sanitize recording request and response bodies +func (s *RecordingSanitizer) AddBodysanitizer(sanitizer StringSanitizer) { + s.bodySanitizer = sanitizer +} + +// AddUriSanitizer configures the supplied StringSanitizer to sanitize recording request and response URLs +func (s *RecordingSanitizer) AddUrlSanitizer(sanitizer StringSanitizer) { + s.urlSanitizer = sanitizer +} + +func (s *RecordingSanitizer) sanitizeHeaders(header http.Header) { + for headerName := range s.headersToSanitize { + if _, ok := header[headerName]; ok { + header[headerName] = sanitizedValueSlice + } + } +} + +func (s *RecordingSanitizer) sanitizeBodies(body *string) { + s.bodySanitizer(body) +} + +func (s *RecordingSanitizer) sanitizeURL(url *string) { + s.urlSanitizer(url) +} + +func (s *RecordingSanitizer) applySaveFilter(i *cassette.Interaction) error { + s.sanitizeHeaders(i.Request.Headers) + s.sanitizeHeaders(i.Response.Headers) + s.sanitizeURL(&i.Request.URL) + if len(i.Request.Body) > 0 { + s.sanitizeBodies(&i.Request.Body) + } + if len(i.Response.Body) > 0 { + s.sanitizeBodies(&i.Response.Body) + } + return nil +} + +func DefaultStringSanitizer(s *string) {} diff --git a/sdk/internal/testframework/recording_sanitizer_test.go b/sdk/internal/testframework/recording_sanitizer_test.go new file mode 100644 index 000000000000..570dfb3b005b --- /dev/null +++ b/sdk/internal/testframework/recording_sanitizer_test.go @@ -0,0 +1,157 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "net/http" + "os" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/dnaeon/go-vcr/cassette" + "github.com/dnaeon/go-vcr/recorder" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type recordingSanitizerTests struct { + suite.Suite +} + +const authHeader string = "Authorization" +const customHeader1 string = "Fooheader" +const customHeader2 string = "Barheader" +const nonSanitizedHeader string = "notsanitized" + +func TestRecordingSanitizer(t *testing.T) { + suite.Run(t, new(recordingSanitizerTests)) +} + +func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { + assert := assert.New(s.T()) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) + + DefaultSanitizer(r) + + req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) + req.Header.Add(authHeader, "superSecret") + + r.RoundTrip(req) + r.Stop() + + assert.Equal(SanitizedValue, req.Header.Get(authHeader)) + + rec, err := cassette.Load(getTestFileName(s.T(), false)) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(SanitizedValue, i.Request.Headers.Get(authHeader)) + } +} + +func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() { + assert := assert.New(s.T()) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) + + target := DefaultSanitizer(r) + target.AddSanitizedHeaders(customHeader1, customHeader2) + + req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) + req.Header.Add(customHeader1, "superSecret") + req.Header.Add(customHeader2, "verySecret") + safeValue := "safeValue" + req.Header.Add(nonSanitizedHeader, safeValue) + + r.RoundTrip(req) + r.Stop() + + assert.Equal(SanitizedValue, req.Header.Get(customHeader1)) + assert.Equal(SanitizedValue, req.Header.Get(customHeader2)) + assert.Equal(safeValue, req.Header.Get(nonSanitizedHeader)) + + rec, err := cassette.Load(getTestFileName(s.T(), false)) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(SanitizedValue, i.Request.Headers.Get(customHeader1)) + assert.Equal(SanitizedValue, i.Request.Headers.Get(customHeader2)) + assert.Equal(safeValue, i.Request.Headers.Get(nonSanitizedHeader)) + } +} + +func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() { + assert := assert.New(s.T()) + secret := "secretvalue" + secretBody := "some body content that contains a " + secret + server, cleanup := mock.NewServer() + server.SetResponse(mock.WithStatusCode(http.StatusCreated), mock.WithBody([]byte(secretBody))) + defer cleanup() + rt := NewMockRoundTripper(server) + r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) + + baseUrl := server.URL() + "/" + + target := DefaultSanitizer(r) + target.AddUrlSanitizer(func(url *string) { + *url = strings.Replace(*url, secret, SanitizedValue, -1) + }) + target.AddBodysanitizer(func(body *string) { + *body = strings.Replace(*body, secret, SanitizedValue, -1) + }) + + req, _ := http.NewRequest(http.MethodPost, baseUrl+secret, closerFromString(secretBody)) + + r.RoundTrip(req) + r.Stop() + + rec, err := cassette.Load(getTestFileName(s.T(), false)) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.NotContains(i.Response.Body, secret) + assert.NotContains(i.Request.URL, secret) + assert.NotContains(i.Request.Body, secret) + assert.Contains(i.Request.URL, SanitizedValue) + assert.Contains(i.Request.Body, SanitizedValue) + assert.Contains(i.Response.Body, SanitizedValue) + } +} + +func (s *recordingSanitizerTests) TearDownSuite() { + assert := assert.New(s.T()) + // cleanup test files + err := os.RemoveAll("testfiles") + assert.Nil(err) +} + +func getTestFileName(t *testing.T, addSuffix bool) string { + name := "testfiles/" + t.Name() + if addSuffix { + name = name + ".yaml" + } + return name +} + +type mockRoundTripper struct { + server *mock.Server +} + +func NewMockRoundTripper(server *mock.Server) *mockRoundTripper { + return &mockRoundTripper{server: server} +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return m.server.Do(req) +} diff --git a/sdk/internal/testframework/recording_test.go b/sdk/internal/testframework/recording_test.go new file mode 100644 index 000000000000..4ed73f5805ce --- /dev/null +++ b/sdk/internal/testframework/recording_test.go @@ -0,0 +1,356 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "fmt" + "io/ioutil" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" + "github.com/dnaeon/go-vcr/cassette" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type recordingTests struct { + suite.Suite +} + +func TestRecording(t *testing.T) { + suite.Run(t, new(recordingTests)) +} + +func (s *recordingTests) TestInitializeRecording() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + expectedMode := Playback + + target, err := NewRecording(context, expectedMode) + assert.Nil(err) + assert.NotNil(target.RecordingFile) + assert.NotNil(target.VariablesFile) + assert.Equal(expectedMode, target.Mode) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestStopDoesNotSaveVariablesWhenNoVariablesExist() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + err = target.Stop() + assert.Nil(err) + + _, err = ioutil.ReadFile(target.VariablesFile) + assert.Equal(true, os.IsNotExist(err)) +} + +func (s *recordingTests) TestRecordedVariables() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { s.T().Log(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + nonExistingEnvVar := "nonExistingEnvVar" + expectedVariableValue := "foobar" + variablesMap := map[string]string{} + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + // optional variables always succeed. + assert.Equal(expectedVariableValue, target.GetOptionalRecordedVariable(nonExistingEnvVar, expectedVariableValue, Default)) + + // non existent variables return an error + val, err := target.GetRecordedVariable(nonExistingEnvVar, Default) + // mark test as succeeded + assert.Equal(envNotExistsError(nonExistingEnvVar), err.Error()) + + // now create the env variable and check that it can be fetched + os.Setenv(nonExistingEnvVar, expectedVariableValue) + defer os.Unsetenv(nonExistingEnvVar) + val, err = target.GetRecordedVariable(nonExistingEnvVar, Default) + assert.Equal(expectedVariableValue, val) + + err = target.Stop() + assert.Nil(err) + + // check that a variables file was created with the correct variable + target.unmarshalVariablesFile(variablesMap) + actualValue, ok := variablesMap[nonExistingEnvVar] + assert.Equal(true, ok) + assert.Equal(expectedVariableValue, actualValue) +} + +func (s *recordingTests) TestRecordedVariablesSanitized() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + SanitizedStringVar := "sanitizedvar" + SanitizedBase64StrigVar := "sanitizedbase64var" + secret := "secretstring" + secretBase64 := "asdfasdf==" + variablesMap := map[string]string{} + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + // call GetOptionalRecordedVariable with the Secret_String VariableType arg + assert.Equal(secret, target.GetOptionalRecordedVariable(SanitizedStringVar, secret, Secret_String)) + + // call GetOptionalRecordedVariable with the Secret_Base64String VariableType arg + assert.Equal(secretBase64, target.GetOptionalRecordedVariable(SanitizedBase64StrigVar, secretBase64, Secret_Base64String)) + + // Calling Stop will save the variables and apply the sanitization options + err = target.Stop() + assert.Nil(err) + + // check that a variables file was created with the correct variables + target.unmarshalVariablesFile(variablesMap) + actualValue, ok := variablesMap[SanitizedStringVar] + assert.Equal(true, ok) + // the saved value is sanitized + assert.Equal(SanitizedValue, actualValue) + + target.unmarshalVariablesFile(variablesMap) + actualValue, ok = variablesMap[SanitizedBase64StrigVar] + assert.Equal(true, ok) + // the saved value is sanitized + assert.Equal(SanitizedBase64Value, actualValue) +} + +func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + expectedVariableName := "someVariable" + expectedVariableValue := "foobar" + addedVariableName := "addedVariable" + addedVariableValue := "fizzbuzz" + variablesMap := map[string]string{} + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + target.GetOptionalRecordedVariable(expectedVariableName, expectedVariableValue, Default) + + err = target.Stop() + assert.Nil(err) + + // check that a variables file was created with the correct variable + target.unmarshalVariablesFile(variablesMap) + actualValue, ok := variablesMap[expectedVariableName] + assert.True(ok) + assert.Equal(expectedVariableValue, actualValue) + + variablesMap = map[string]string{} + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + // add a new variable to the existing batch + target2.GetOptionalRecordedVariable(addedVariableName, addedVariableValue, Default) + + err = target2.Stop() + assert.Nil(err) + + // check that a variables file was created with the variables loaded from the previous recording + target2.unmarshalVariablesFile(variablesMap) + actualValue, ok = variablesMap[addedVariableName] + assert.Truef(ok, fmt.Sprintf("Should have found %s", addedVariableName)) + assert.Equal(addedVariableValue, actualValue) + actualValue, ok = variablesMap[expectedVariableName] + assert.Truef(ok, fmt.Sprintf("Should have found %s", expectedVariableName)) + assert.Equal(expectedVariableValue, actualValue) +} + +func (s *recordingTests) TestUUID() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedUUID1 := target.UUID() + recordedUUID1a := target.UUID() + assert.NotEqual(recordedUUID1.String(), recordedUUID1a.String()) + + err = target.Stop() + assert.Nil(err) + + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedUUID2 := target2.UUID() + + // The two generated UUIDs should be the same since target2 loaded the saved random seed from target + assert.Equal(recordedUUID1.String(), recordedUUID2.String()) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestNow() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedNow1 := target.Now() + + time.Sleep(time.Millisecond * 100) + + recordedNow1a := target.Now() + assert.Equal(recordedNow1.UnixNano(), recordedNow1a.UnixNano()) + + err = target.Stop() + assert.Nil(err) + + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + recordedNow2 := target2.Now() + + // The two generated nows should be the same since target2 loaded the saved random seed from target + assert.Equal(recordedNow1.UnixNano(), recordedNow2.UnixNano()) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestGenerateAlphaNumericID() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + prefix := "myprefix" + + target, err := NewRecording(context, Playback) + assert.Nil(err) + + generated1, err := target.GenerateAlphaNumericID(prefix, 10, true) + + assert.Equal(10, len(generated1)) + assert.Equal(true, strings.HasPrefix(generated1, prefix)) + + generated1a, err := target.GenerateAlphaNumericID(prefix, 10, true) + assert.NotEqual(generated1, generated1a) + + err = target.Stop() + assert.Nil(err) + + target2, err := NewRecording(context, Playback) + assert.Nil(err) + + generated2, err := target2.GenerateAlphaNumericID(prefix, 10, true) + + // The two generated Ids should be the same since target2 loaded the saved random seed from target + assert.Equal(generated2, generated1) + + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestRecordRequestsAndDoMatching() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + + target, err := NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + path, err := target.GenerateAlphaNumericID("", 5, true) + reqUrl := server.URL() + "/" + path + + req, _ := http.NewRequest(http.MethodPost, reqUrl, nil) + + // record the request + target.Do(req) + err = target.Stop() + assert.Nil(err) + + rec, err := cassette.Load(target.SessionName) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(reqUrl, i.Request.URL) + } + + // re-initialize the recording + target, err = NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + // re-create the random url using the recorded variables + path, err = target.GenerateAlphaNumericID("", 5, true) + reqUrl = server.URL() + "/" + path + req, _ = http.NewRequest(http.MethodPost, reqUrl, nil) + + // playback the request + target.Do(req) + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TestRecordRequestsAndFailMatchingForMissingRecording() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { s.T().Log(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + server, cleanup := mock.NewServer() + server.SetResponse() + defer cleanup() + rt := NewMockRoundTripper(server) + + target, err := NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + path, err := target.GenerateAlphaNumericID("", 5, true) + reqUrl := server.URL() + "/" + path + + req, _ := http.NewRequest(http.MethodPost, reqUrl, nil) + + // record the request + target.Do(req) + err = target.Stop() + assert.Nil(err) + + rec, err := cassette.Load(target.SessionName) + assert.Nil(err) + + for _, i := range rec.Interactions { + assert.Equal(reqUrl, i.Request.URL) + } + + // re-initialize the recording + target, err = NewRecording(context, Playback) + target.recorder.SetTransport(rt) + + // re-create the random url using the recorded variables + reqUrl = server.URL() + "/" + "mismatchedRequest" + req, _ = http.NewRequest(http.MethodPost, reqUrl, nil) + + // playback the request + _, err = target.Do(req) + assert.Equal(missingRequestError(req), err.Error()) + // mark succeeded + err = target.Stop() + assert.Nil(err) +} + +func (s *recordingTests) TearDownSuite() { + + // cleanup test files + err := os.RemoveAll("recordings") + assert.Nil(s.T(), err) +} diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go new file mode 100644 index 000000000000..38997a5ffa50 --- /dev/null +++ b/sdk/internal/testframework/request_matcher.go @@ -0,0 +1,111 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "bytes" + "fmt" + "io/ioutil" + "net/http" + "reflect" + + "github.com/dnaeon/go-vcr/cassette" +) + +type RequestMatcher struct { + ignoredHeaders map[string]*string +} + +var ignoredHeaders = map[string]*string{ + "Date": nil, + "X-Ms-Date": nil, + "x-ms-date": nil, + "x-ms-client-request-id": nil, + "User-Agent": nil, + "Request-Id": nil, + "traceparent": nil, + "Authorization": nil, +} + +var recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." +var requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." +var headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" +var methodMismatch = "Test recording methods do not match. request: %s, recording: %s" +var urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" +var bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" + +func compareBodies(r *http.Request, i cassette.Request, c TestContext) bool { + body := bytes.Buffer{} + if r.Body != nil { + _, err := body.ReadFrom(r.Body) + if err != nil { + return false + } + r.Body = ioutil.NopCloser(&body) + } + bodiesMatch := body.String() == i.Body + if !bodiesMatch { + c.Log(fmt.Sprintf(bodiesMismatch, body.String(), i.Body)) + } + return bodiesMatch +} + +func compareURLs(r *http.Request, i cassette.Request, c TestContext) bool { + if r.URL.String() != i.URL { + c.Log(fmt.Sprintf(urlMismatch, r.URL.String(), i.URL)) + return false + } + return true +} + +func compareMethods(r *http.Request, i cassette.Request, c TestContext) bool { + if r.Method != i.Method { + c.Log(fmt.Sprintf(methodMismatch, r.Method, i.Method)) + return false + } + return true +} + +func compareHeaders(r *http.Request, i cassette.Request, c TestContext) bool { + unVisitedCassetteKeys := make(map[string]*string, len(i.Headers)) + // clone the cassette keys to track which we have seen + for k := range i.Headers { + if _, ignore := ignoredHeaders[k]; ignore { + // don't copy ignored headers + continue + } + unVisitedCassetteKeys[k] = nil + } + //iterate through all the request headers to compare them to cassette headers + for key, requestHeader := range r.Header { + if _, ignore := ignoredHeaders[key]; ignore { + // this is an ignorable header + continue + } + delete(unVisitedCassetteKeys, key) + if recordedHeader, foundMatch := i.Headers[key]; foundMatch { + headersMatch := reflect.DeepEqual(requestHeader, recordedHeader) + if !headersMatch { + // headers don't match + c.Log(fmt.Sprintf(headerValuesMismatch, key, requestHeader, recordedHeader)) + return false + } + + } else { + // header not found + c.Log(fmt.Sprintf(recordingHeaderMissing, key)) + return false + } + } + if len(unVisitedCassetteKeys) > 0 { + // headers exist in the recording that do not exist in the request + for headerName := range unVisitedCassetteKeys { + c.Log(fmt.Sprintf(requestHeaderMissing, headerName)) + } + return false + } + return true +} diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/testframework/request_matcher_test.go new file mode 100644 index 000000000000..90c945458917 --- /dev/null +++ b/sdk/internal/testframework/request_matcher_test.go @@ -0,0 +1,193 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +import ( + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" + "github.com/dnaeon/go-vcr/cassette" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type requestMatcherTests struct { + suite.Suite +} + +func TestRequestMatcher(t *testing.T) { + suite.Run(t, new(requestMatcherTests)) +} + +const matchedBody string = "Matching body." +const unMatchedBody string = "This body does not match." + +func (s *requestMatcherTests) TestCompareBodies() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + req := http.Request{Body: closerFromString(matchedBody)} + recReq := cassette.Request{Body: matchedBody} + + isMatch := compareBodies(&req, recReq, context) + + assert.Equal(true, isMatch) + + // make the requests mis-match + req.Body = closerFromString((unMatchedBody)) + + isMatch = compareBodies(&req, recReq, context) + + assert.False(isMatch) +} + +func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + for headerName := range ignoredHeaders { + reqHeaders[headerName] = []string{uuid.New().String()} + recordedHeaders[headerName] = []string{uuid.New().String()} + } + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // All headers match + assert.True(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + headerValue := []string{"some value"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + assert.True(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + header2 := "header2" + headerValue := []string{"some value"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // add a new header to the just req + reqHeaders[header2] = headerValue + + assert.False(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + header2 := "header2" + headerValue := []string{"some value"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // add a new header to just the recording + recordedHeaders[header2] = headerValue + + assert.False(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + + // populate only ignored headers that do not match + reqHeaders := make(http.Header) + recordedHeaders := make(http.Header) + header1 := "header1" + header2 := "header2" + headerValue := []string{"some value"} + mismatch := []string{"mismatch"} + + reqHeaders[header1] = headerValue + recordedHeaders[header1] = headerValue + + req := http.Request{Header: reqHeaders} + recReq := cassette.Request{Headers: recordedHeaders} + + // header names match but values are different + recordedHeaders[header2] = headerValue + reqHeaders[header2] = mismatch + + assert.False(compareHeaders(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareURLs() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + scheme := "https" + host := "foo.bar" + req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}} + recReq := cassette.Request{URL: scheme + "://" + host} + + assert.True(compareURLs(&req, recReq, context)) + + req.URL.Path = "noMatch" + + assert.False(compareURLs(&req, recReq, context)) +} + +func (s *requestMatcherTests) TestCompareMethods() { + assert := assert.New(s.T()) + context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + methodPost := "POST" + methodPatch := "PATCH" + req := http.Request{Method: methodPost} + recReq := cassette.Request{Method: methodPost} + + assert.True(compareMethods(&req, recReq, context)) + + req.Method = methodPatch + + assert.False(compareMethods(&req, recReq, context)) +} + +func closerFromString(content string) io.ReadCloser { + return ioutil.NopCloser(strings.NewReader(content)) +} diff --git a/sdk/internal/testframework/testcontext.go b/sdk/internal/testframework/testcontext.go new file mode 100644 index 000000000000..97bcc132e5ef --- /dev/null +++ b/sdk/internal/testframework/testcontext.go @@ -0,0 +1,50 @@ +// +build go1.13 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package testframework + +type TestContext interface { + Fail(string) + Log(string) + Name() string + IsFailed() bool +} + +type testContext struct { + failed bool + fail Failer + log Logger + name string +} + +type Failer func(string) +type Logger func(string) +type Name func() string + +// NewTestContext initializes a new TestContext +func NewTestContext(failer Failer, logger Logger, name Name) TestContext { + return &testContext{fail: failer, log: logger, name: name()} +} + +// Fail calls the Failer func and makes IsFailed return true. +func (c *testContext) Fail(msg string) { + c.failed = true + c.fail(msg) +} + +// Log calls the Logger func. +func (c *testContext) Log(msg string) { + c.log(msg) +} + +// Name calls the Name func and returns the result. +func (c *testContext) Name() string { + return c.name +} + +// IsFailed returns true if the Failer has been called. +func (c *testContext) IsFailed() bool { + return c.failed +} diff --git a/sdk/internal/uuid/uuid.go b/sdk/internal/uuid/uuid.go index 4b288d81fecd..2f3c55d0e633 100644 --- a/sdk/internal/uuid/uuid.go +++ b/sdk/internal/uuid/uuid.go @@ -41,6 +41,20 @@ func New() UUID { return u } +// FromSource returns a new uuid based on the supplied rand.Source as a seed. +func FromSource(src rand.Source) UUID { + u := UUID{} + // Set all bits to randomly (or pseudo-randomly) chosen values. + // math/rand.Read() is no-fail so we omit any error checking. + rnd := rand.New(src) + rnd.Read(u[:]) + u[8] = (u[8] | reservedRFC4122) & 0x7F // u.setVariant(ReservedRFC4122) + + var version byte = 4 + u[6] = (u[6] & 0xF) | (version << 4) // u.setVersion(4) + return u +} + // String returns an unparsed version of the generated UUID sequence. func (u UUID) String() string { return fmt.Sprintf("%x-%x-%x-%x-%x", u[0:4], u[4:6], u[6:8], u[8:10], u[10:]) From 9c9ded075959975dec842763028f2e0639ea7a83 Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Thu, 10 Jun 2021 16:54:33 -0500 Subject: [PATCH 02/11] extensible matchers --- sdk/internal/testframework/recording.go | 10 +- sdk/internal/testframework/request_matcher.go | 97 ++++++++++++++----- .../testframework/request_matcher_test.go | 30 +++--- 3 files changed, 98 insertions(+), 39 deletions(-) diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/testframework/recording.go index 2b008d55f3f8..38e4495945be 100644 --- a/sdk/internal/testframework/recording.go +++ b/sdk/internal/testframework/recording.go @@ -34,6 +34,7 @@ type Recording struct { src rand.Source now *time.Time Sanitizer *RecordingSanitizer + Matcher *RequestMatcher c TestContext } @@ -101,6 +102,7 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) { } // set the recorder Matcher + recording.Matcher = DefaultMatcher(c) rec.SetMatcher(recording.matchRequest) // wire up the sanitizer @@ -263,10 +265,10 @@ func getOptionalEnv(name string, defaultValue string) *string { } func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool { - isMatch := compareMethods(req, rec, r.c) && - compareURLs(req, rec, r.c) && - compareHeaders(req, rec, r.c) && - compareBodies(req, rec, r.c) + isMatch := r.Matcher.compareMethods(req, rec.Method) && + r.Matcher.compareURLs(req, rec.URL) && + r.Matcher.compareHeaders(req, rec) && + r.Matcher.compareBodies(req, rec.Body) return isMatch } diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index 38997a5ffa50..fb8fc1fd22f6 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -16,9 +16,16 @@ import ( ) type RequestMatcher struct { - ignoredHeaders map[string]*string + context TestContext + IgnoredHeaders map[string]*string + bodyMatcher StringMatcher + urlMatcher StringMatcher + methodMatcher StringMatcher } +type StringMatcher func(reqVal string, recVal string) bool +type matcherWrapper func(matcher StringMatcher, testContext TestContext) bool + var ignoredHeaders = map[string]*string{ "Date": nil, "X-Ms-Date": nil, @@ -37,39 +44,81 @@ var methodMismatch = "Test recording methods do not match. request: %s, recordin var urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" var bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" -func compareBodies(r *http.Request, i cassette.Request, c TestContext) bool { +func DefaultMatcher(testContext TestContext) *RequestMatcher { + // The default sanitizer sanitizes the Authorization header + matcher := &RequestMatcher{ + context: testContext, + IgnoredHeaders: ignoredHeaders, + } + matcher.SetBodyMatcher(defaultStringMatcher) + matcher.SetURLMatcher(defaultStringMatcher) + matcher.SetMethodMatcher(defaultStringMatcher) + + return matcher +} + +func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) { + m.setMatcher(matcher, bodiesMismatch) +} + +func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) { + m.setMatcher(matcher, urlMismatch) +} + +func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { + m.setMatcher(matcher, methodMismatch) +} + +func (m *RequestMatcher) setMatcher(matcher StringMatcher, message string) { + m.bodyMatcher = func(reqVal string, recVal string) bool { + isMatch := matcher(reqVal, recVal) + if !isMatch { + m.context.Log(fmt.Sprintf(message, recVal, recVal)) + } + return isMatch + } +} + +func defaultStringMatcher(s1 string, s2 string) bool { + return s1 == s2 +} + +func getBody(r *http.Request) string { body := bytes.Buffer{} if r.Body != nil { _, err := body.ReadFrom(r.Body) if err != nil { - return false + return "could not parse body: " + err.Error() } r.Body = ioutil.NopCloser(&body) } - bodiesMatch := body.String() == i.Body - if !bodiesMatch { - c.Log(fmt.Sprintf(bodiesMismatch, body.String(), i.Body)) - } - return bodiesMatch + return body.String() } -func compareURLs(r *http.Request, i cassette.Request, c TestContext) bool { - if r.URL.String() != i.URL { - c.Log(fmt.Sprintf(urlMismatch, r.URL.String(), i.URL)) - return false - } - return true +func getUrl(r *http.Request) string { + return r.URL.String() } -func compareMethods(r *http.Request, i cassette.Request, c TestContext) bool { - if r.Method != i.Method { - c.Log(fmt.Sprintf(methodMismatch, r.Method, i.Method)) - return false - } - return true +func getMethod(r *http.Request) string { + return r.Method +} + +func (m *RequestMatcher) compareBodies(r *http.Request, recordedBody string) bool { + body := getBody(r) + return m.bodyMatcher(body, recordedBody) +} + +func (m *RequestMatcher) compareURLs(r *http.Request, recordedUrl string) bool { + body := getUrl(r) + return m.urlMatcher(body, recordedUrl) +} + +func (m *RequestMatcher) compareMethods(r *http.Request, recordedMethod string) bool { + body := getMethod(r) + return m.urlMatcher(body, recordedMethod) } -func compareHeaders(r *http.Request, i cassette.Request, c TestContext) bool { +func (m *RequestMatcher) compareHeaders(r *http.Request, i cassette.Request) bool { unVisitedCassetteKeys := make(map[string]*string, len(i.Headers)) // clone the cassette keys to track which we have seen for k := range i.Headers { @@ -90,20 +139,20 @@ func compareHeaders(r *http.Request, i cassette.Request, c TestContext) bool { headersMatch := reflect.DeepEqual(requestHeader, recordedHeader) if !headersMatch { // headers don't match - c.Log(fmt.Sprintf(headerValuesMismatch, key, requestHeader, recordedHeader)) + m.context.Log(fmt.Sprintf(headerValuesMismatch, key, requestHeader, recordedHeader)) return false } } else { // header not found - c.Log(fmt.Sprintf(recordingHeaderMissing, key)) + m.context.Log(fmt.Sprintf(recordingHeaderMissing, key)) return false } } if len(unVisitedCassetteKeys) > 0 { // headers exist in the recording that do not exist in the request for headerName := range unVisitedCassetteKeys { - c.Log(fmt.Sprintf(requestHeaderMissing, headerName)) + m.context.Log(fmt.Sprintf(requestHeaderMissing, headerName)) } return false } diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/testframework/request_matcher_test.go index 90c945458917..f85e7110ee22 100644 --- a/sdk/internal/testframework/request_matcher_test.go +++ b/sdk/internal/testframework/request_matcher_test.go @@ -33,18 +33,19 @@ const unMatchedBody string = "This body does not match." func (s *requestMatcherTests) TestCompareBodies() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := DefaultMatcher(context) req := http.Request{Body: closerFromString(matchedBody)} recReq := cassette.Request{Body: matchedBody} - isMatch := compareBodies(&req, recReq, context) + isMatch := matcher.compareBodies(&req, recReq.Body) assert.Equal(true, isMatch) // make the requests mis-match req.Body = closerFromString((unMatchedBody)) - isMatch = compareBodies(&req, recReq, context) + isMatch = matcher.compareBodies(&req, recReq.Body) assert.False(isMatch) } @@ -52,6 +53,7 @@ func (s *requestMatcherTests) TestCompareBodies() { func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := DefaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -65,12 +67,13 @@ func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { recReq := cassette.Request{Headers: recordedHeaders} // All headers match - assert.True(compareHeaders(&req, recReq, context)) + assert.True(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := DefaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -84,12 +87,13 @@ func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { req := http.Request{Header: reqHeaders} recReq := cassette.Request{Headers: recordedHeaders} - assert.True(compareHeaders(&req, recReq, context)) + assert.True(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := DefaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -107,12 +111,13 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { // add a new header to the just req reqHeaders[header2] = headerValue - assert.False(compareHeaders(&req, recReq, context)) + assert.False(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := DefaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -130,12 +135,13 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { // add a new header to just the recording recordedHeaders[header2] = headerValue - assert.False(compareHeaders(&req, recReq, context)) + assert.False(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := DefaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -155,7 +161,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { recordedHeaders[header2] = headerValue reqHeaders[header2] = mismatch - assert.False(compareHeaders(&req, recReq, context)) + assert.False(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareURLs() { @@ -165,12 +171,13 @@ func (s *requestMatcherTests) TestCompareURLs() { host := "foo.bar" req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}} recReq := cassette.Request{URL: scheme + "://" + host} + matcher := DefaultMatcher(context) - assert.True(compareURLs(&req, recReq, context)) + assert.True(matcher.compareURLs(&req, recReq.URL)) req.URL.Path = "noMatch" - assert.False(compareURLs(&req, recReq, context)) + assert.False(matcher.compareURLs(&req, recReq.URL)) } func (s *requestMatcherTests) TestCompareMethods() { @@ -180,12 +187,13 @@ func (s *requestMatcherTests) TestCompareMethods() { methodPatch := "PATCH" req := http.Request{Method: methodPost} recReq := cassette.Request{Method: methodPost} + matcher := DefaultMatcher(context) - assert.True(compareMethods(&req, recReq, context)) + assert.True(matcher.compareMethods(&req, recReq.Method)) req.Method = methodPatch - assert.False(compareMethods(&req, recReq, context)) + assert.False(matcher.compareMethods(&req, recReq.Method)) } func closerFromString(content string) io.ReadCloser { From a7dce35a4af0d3f9fd9cdf0acf406d9ddf9c254b Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Thu, 10 Jun 2021 18:26:35 -0500 Subject: [PATCH 03/11] fix --- sdk/internal/testframework/request_matcher.go | 46 ++++++++++++------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index fb8fc1fd22f6..3e74053af292 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -50,36 +50,50 @@ func DefaultMatcher(testContext TestContext) *RequestMatcher { context: testContext, IgnoredHeaders: ignoredHeaders, } - matcher.SetBodyMatcher(defaultStringMatcher) - matcher.SetURLMatcher(defaultStringMatcher) - matcher.SetMethodMatcher(defaultStringMatcher) + matcher.SetBodyMatcher(func(req string, rec string) bool { + return DefaultStringMatcher(req, rec) + }) + matcher.SetURLMatcher(func(req string, rec string) bool { + return DefaultStringMatcher(req, rec) + }) + matcher.SetMethodMatcher(func(req string, rec string) bool { + return DefaultStringMatcher(req, rec) + }) return matcher } func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) { - m.setMatcher(matcher, bodiesMismatch) + m.bodyMatcher = func(reqVal string, recVal string) bool { + isMatch := matcher(reqVal, recVal) + if !isMatch { + m.context.Log(fmt.Sprintf(bodiesMismatch, recVal, recVal)) + } + return isMatch + } } func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) { - m.setMatcher(matcher, urlMismatch) + m.urlMatcher = func(reqVal string, recVal string) bool { + isMatch := matcher(reqVal, recVal) + if !isMatch { + m.context.Log(fmt.Sprintf(urlMismatch, recVal, recVal)) + } + return isMatch + } } func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { - m.setMatcher(matcher, methodMismatch) -} - -func (m *RequestMatcher) setMatcher(matcher StringMatcher, message string) { - m.bodyMatcher = func(reqVal string, recVal string) bool { + m.methodMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) if !isMatch { - m.context.Log(fmt.Sprintf(message, recVal, recVal)) + m.context.Log(fmt.Sprintf(methodMismatch, recVal, recVal)) } return isMatch } } -func defaultStringMatcher(s1 string, s2 string) bool { +func DefaultStringMatcher(s1 string, s2 string) bool { return s1 == s2 } @@ -109,13 +123,13 @@ func (m *RequestMatcher) compareBodies(r *http.Request, recordedBody string) boo } func (m *RequestMatcher) compareURLs(r *http.Request, recordedUrl string) bool { - body := getUrl(r) - return m.urlMatcher(body, recordedUrl) + url := getUrl(r) + return m.urlMatcher(url, recordedUrl) } func (m *RequestMatcher) compareMethods(r *http.Request, recordedMethod string) bool { - body := getMethod(r) - return m.urlMatcher(body, recordedMethod) + method := getMethod(r) + return m.methodMatcher(method, recordedMethod) } func (m *RequestMatcher) compareHeaders(r *http.Request, i cassette.Request) bool { From d4a818bc46f7256ee9633852ee0100c5d777ddf0 Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Thu, 10 Jun 2021 18:29:31 -0500 Subject: [PATCH 04/11] const --- sdk/internal/testframework/request_matcher.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index 3e74053af292..db58d017af4a 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -37,12 +37,12 @@ var ignoredHeaders = map[string]*string{ "Authorization": nil, } -var recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." -var requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." -var headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" -var methodMismatch = "Test recording methods do not match. request: %s, recording: %s" -var urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" -var bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" +const recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." +const requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." +const headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" +const methodMismatch = "Test recording methods do not match. request: %s, recording: %s" +const urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" +const bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" func DefaultMatcher(testContext TestContext) *RequestMatcher { // The default sanitizer sanitizes the Authorization header From dba733ef3bf8b6087d2cd87fbfdaad710917d72b Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Thu, 10 Jun 2021 18:41:07 -0500 Subject: [PATCH 05/11] const formatting --- sdk/internal/testframework/request_matcher.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index db58d017af4a..c9405d03593a 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -37,12 +37,14 @@ var ignoredHeaders = map[string]*string{ "Authorization": nil, } -const recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." -const requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." -const headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" -const methodMismatch = "Test recording methods do not match. request: %s, recording: %s" -const urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" -const bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" +const ( + recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." + requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." + headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" + methodMismatch = "Test recording methods do not match. request: %s, recording: %s" + urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" + bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" +) func DefaultMatcher(testContext TestContext) *RequestMatcher { // The default sanitizer sanitizes the Authorization header From 921c67f7d8afeb91605c4ed912adc462da72636f Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Fri, 11 Jun 2021 09:41:17 -0500 Subject: [PATCH 06/11] readme --- sdk/internal/testframework/README.md | 161 +++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 sdk/internal/testframework/README.md diff --git a/sdk/internal/testframework/README.md b/sdk/internal/testframework/README.md new file mode 100644 index 000000000000..d545700c3535 --- /dev/null +++ b/sdk/internal/testframework/README.md @@ -0,0 +1,161 @@ +# Azure SDK for Go Recorded Test Framework + +[![Build Status](https://dev.azure.com/azure-sdk/public/_apis/build/status/go/Azure.azure-sdk-for-go?branchName=master)](https://dev.azure.com/azure-sdk/public/_build/latest?definitionId=1842&branchName=master) + +The `testframework` package makes it easy to add recorded tests to your track-2 client package. +Below are some examples that walk through setting up a recorded test end to end. + +## Examples + +### Initializing a Recording instance for a test + +The first step in instrumenting a client to interact with recorded tests is to create a `TestContext`. +This acts as the interface between the recorded test framework and your chosen test package. +In these examples we'll use testify's [assert](https://pkg.go.dev/github.com/stretchr/testify/assert), +but you can use the framework of your choice. + +In the snippet below, demonstrates an example test setup func in which we are initializing the `TestContext` +with the methods that will be invoked when your recorded test needs to Log, Fail, get the Name of the test, +or indicate that the test IsFailed. + +***Note**: an instance of TestContext should be initialized for each test.* + +```go +// a map to store our created test contexts +var clientsMap map[string]*testContext = make(map[string]*testContext) + +// recordedTestSetup is called before each test execution by the test suite's BeforeTest method +func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) { + var accountName string + var suffix string + var cred *SharedKeyCredential + var secret string + var uri string + assert := assert.New(t) + + // init the test framework + context := testframework.NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { t.Log(msg) }, func() string { return testName }) + //mode should be testframework.Playback. This will automatically record if no test recording is available and playback if it is. + recording, err := testframework.NewRecording(context, mode) + assert.Nil(err) +``` + +After creating the TestContext, it must be passed to a new instance of `Recording` along with the current test mode. +`Recording` is the main component of the testframework package. + +```go +//func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) { +// <...> + recording, err := testframework.NewRecording(context, mode) + assert.Nil(err) +``` + +### Initializing recorded variables + +A key component to recorded tests is recorded variables. +They allow creation of values that stay with the test recording so that playback of service operations is consistent. + +In the snippet below we are calling `GetRecordedVariable` to acquire details such as the service account name and +client secret to configure the client. + +```go +//func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) { +// <...> + accountName, err := recording.GetRecordedVariable(storageAccountNameEnvVar, testframework.Default) + suffix := recording.GetOptionalRecordedVariable(storageEndpointSuffixEnvVar, DefaultStorageSuffix, testframework.Default) + secret, err := recording.GetRecordedVariable(storageAccountKeyEnvVar, testframework.Secret_Base64String) + cred, _ := NewSharedKeyCredential(accountName, secret) + uri := storageURI(accountName, suffix) +``` + +The last step is to instrument your client by replacing its transport with your `Recording` instance. +`Recording` satisfies the `azcore.Transport` interface. + +```go +//func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) { +// <...> + // Set our client's HTTPClient to our recording instance. + // Optionally, we can also configure MaxRetries to -1 to avoid the default retry behavior. + client, err := NewTableServiceClient(uri, cred, &TableClientOptions{HTTPClient: recording, Retry: azcore.RetryOptions{MaxRetries: -1}}) + assert.Nil(err) + + // either return your client instance, or store it somewhere that your test can use it for test execution. + clientsMap[testName] = &testContext{client: client, recording: recording, context: &context} +} + + +func getTestContext(key string) *testContext { + return clientsMap[key] +} +``` + +### Completing the recorded test session + +After the test run completes we need to signal the `Recording` instance to save the recording. + +```go +// recordedTestTeardown fetches the context from our map based on test name and calls Stop on the Recording instance. +func recordedTestTeardown(key string) { + context, ok := clientsMap[key] + if ok && !(*context.context).IsFailed() { + context.recording.Stop() + } +} +``` + +### Setting up a test to use our Recording instance + +Test frameworks like testify suite allow for configuration of a `BeforeTest` method to be executed before each test. +We can use this to call our `recordedTestSetup` method + +Below is an example test setup which executes a single test. + +```go +package aztable + +import ( + "errors" + "fmt" + "net/http" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/internal/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/internal/testframework" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type tableServiceClientLiveTests struct { + suite.Suite + mode testframework.RecordMode +} + +// Hookup to the testing framework +func TestServiceClient_Storage(t *testing.T) { + storage := tableServiceClientLiveTests{endpointType: StorageEndpoint, mode: testframework.Playback /* change to Record to re-record tests */} + suite.Run(t, &storage) +} + +func (s *tableServiceClientLiveTests) TestCreateTable() { + assert := assert.New(s.T()) + context := getTestContext(s.T().Name()) + // generate a random recorded value for our table name. + tableName, err := context.recording.GenerateAlphaNumericID(tableNamePrefix, 20, true) + + resp, err := context.client.Create(ctx, tableName) + defer context.client.Delete(ctx, tableName) + + assert.Nil(err) + assert.Equal(*resp.TableResponse.TableName, tableName) +} + +func (s *tableServiceClientLiveTests) BeforeTest(suite string, test string) { + // setup the test environment + recordedTestSetup(s.T(), s.T().Name(), s.endpointType, s.mode) +} + +func (s *tableServiceClientLiveTests) AfterTest(suite string, test string) { + // teardown the test context + recordedTestTeardown(s.T().Name()) +} +``` From 438d7e1df788ae0ce75e3f268b2465cc65a0c2d1 Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Tue, 15 Jun 2021 08:47:28 -0500 Subject: [PATCH 07/11] fb --- sdk/internal/testframework/README.md | 15 +++++--- sdk/internal/testframework/recording.go | 24 +++++++----- .../testframework/recording_sanitizer.go | 4 +- .../testframework/recording_sanitizer_test.go | 6 +-- sdk/internal/testframework/recording_test.go | 14 +++---- sdk/internal/testframework/request_matcher.go | 38 +++++++++++-------- .../testframework/request_matcher_test.go | 16 ++++---- 7 files changed, 67 insertions(+), 50 deletions(-) diff --git a/sdk/internal/testframework/README.md b/sdk/internal/testframework/README.md index d545700c3535..a3a496f048bd 100644 --- a/sdk/internal/testframework/README.md +++ b/sdk/internal/testframework/README.md @@ -21,8 +21,13 @@ or indicate that the test IsFailed. ***Note**: an instance of TestContext should be initialized for each test.* ```go +type testState struct { + recording *testframework.Recording + client *TableServiceClient + context *testframework.TestContext +} // a map to store our created test contexts -var clientsMap map[string]*testContext = make(map[string]*testContext) +var clientsMap map[string]*testState = make(map[string]*testState) // recordedTestSetup is called before each test execution by the test suite's BeforeTest method func recordedTestSetup(t *testing.T, testName string, mode testframework.RecordMode) { @@ -80,11 +85,11 @@ The last step is to instrument your client by replacing its transport with your assert.Nil(err) // either return your client instance, or store it somewhere that your test can use it for test execution. - clientsMap[testName] = &testContext{client: client, recording: recording, context: &context} + clientsMap[testName] = &testState{client: client, recording: recording, context: &context} } -func getTestContext(key string) *testContext { +func getTestState(key string) *testState { return clientsMap[key] } ``` @@ -132,7 +137,7 @@ type tableServiceClientLiveTests struct { // Hookup to the testing framework func TestServiceClient_Storage(t *testing.T) { - storage := tableServiceClientLiveTests{endpointType: StorageEndpoint, mode: testframework.Playback /* change to Record to re-record tests */} + storage := tableServiceClientLiveTests{mode: testframework.Playback /* change to Record to re-record tests */} suite.Run(t, &storage) } @@ -151,7 +156,7 @@ func (s *tableServiceClientLiveTests) TestCreateTable() { func (s *tableServiceClientLiveTests) BeforeTest(suite string, test string) { // setup the test environment - recordedTestSetup(s.T(), s.T().Name(), s.endpointType, s.mode) + recordedTestSetup(s.T(), s.T().Name(), s.mode) } func (s *tableServiceClientLiveTests) AfterTest(suite string, test string) { diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/testframework/recording.go index 38e4495945be..9a74549a1568 100644 --- a/sdk/internal/testframework/recording.go +++ b/sdk/internal/testframework/recording.go @@ -64,8 +64,11 @@ const ( type VariableType string const ( - Default VariableType = "default" - Secret_String VariableType = "secret_string" + // NoSanitization indicates that the recorded value should not be sanitized. + NoSanitization VariableType = "default" + // Secret_String indicates that the recorded value should be replaced with a sanitized value. + Secret_String VariableType = "secret_string" + // Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value. Secret_Base64String VariableType = "secret_base64String" ) @@ -102,18 +105,18 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) { } // set the recorder Matcher - recording.Matcher = DefaultMatcher(c) + recording.Matcher = defaultMatcher(c) rec.SetMatcher(recording.matchRequest) // wire up the sanitizer - recording.Sanitizer = DefaultSanitizer(rec) + recording.Sanitizer = defaultSanitizer(rec) return recording, err } -// GetRecordedVariable returns a recorded variable. If the variable is not found we return an error -// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation. -func (r *Recording) GetRecordedVariable(name string, variableType VariableType) (string, error) { +// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error. +// variableType determines how the recorded variable will be saved. +func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) { var err error result, ok := r.previousSessionVariables[name] if !ok || r.Mode == Live { @@ -128,9 +131,10 @@ func (r *Recording) GetRecordedVariable(name string, variableType VariableType) return *result, err } -// GetOptionalRecordedVariable returns a recorded variable with a fallback default value -// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation. -func (r *Recording) GetOptionalRecordedVariable(name string, defaultValue string, variableType VariableType) string { +// GetOptionalEnvVar returns a recorded environment variable with a fallback default value. +// default Value configures the fallback value to be returned if the environment variable is not set. +// variableType determines how the recorded variable will be saved. +func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string { result, ok := r.previousSessionVariables[name] if !ok || r.Mode == Live { result = getOptionalEnv(name, defaultValue) diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/recording_sanitizer.go index 873844b7703b..16d728688d83 100644 --- a/sdk/internal/testframework/recording_sanitizer.go +++ b/sdk/internal/testframework/recording_sanitizer.go @@ -26,7 +26,9 @@ const SanitizedBase64Value string = "Kg==" var sanitizedValueSlice = []string{SanitizedValue} -func DefaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { +// defaultSanitizer returns a new RecordingSanitizer with the default sanitizing behavior. +// To customize sanitization, call AddSanitizedHeaders, AddBodySanitizer, or AddUrlSanitizer. +func defaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { // The default sanitizer sanitizes the Authorization header s := &RecordingSanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} recorder.AddSaveFilter(s.applySaveFilter) diff --git a/sdk/internal/testframework/recording_sanitizer_test.go b/sdk/internal/testframework/recording_sanitizer_test.go index 570dfb3b005b..23adad100ea4 100644 --- a/sdk/internal/testframework/recording_sanitizer_test.go +++ b/sdk/internal/testframework/recording_sanitizer_test.go @@ -39,7 +39,7 @@ func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { rt := NewMockRoundTripper(server) r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) - DefaultSanitizer(r) + defaultSanitizer(r) req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) req.Header.Add(authHeader, "superSecret") @@ -65,7 +65,7 @@ func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() { rt := NewMockRoundTripper(server) r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) - target := DefaultSanitizer(r) + target := defaultSanitizer(r) target.AddSanitizedHeaders(customHeader1, customHeader2) req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) @@ -103,7 +103,7 @@ func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() { baseUrl := server.URL() + "/" - target := DefaultSanitizer(r) + target := defaultSanitizer(r) target.AddUrlSanitizer(func(url *string) { *url = strings.Replace(*url, secret, SanitizedValue, -1) }) diff --git a/sdk/internal/testframework/recording_test.go b/sdk/internal/testframework/recording_test.go index 4ed73f5805ce..15e4ea2e4b04 100644 --- a/sdk/internal/testframework/recording_test.go +++ b/sdk/internal/testframework/recording_test.go @@ -70,17 +70,17 @@ func (s *recordingTests) TestRecordedVariables() { assert.Nil(err) // optional variables always succeed. - assert.Equal(expectedVariableValue, target.GetOptionalRecordedVariable(nonExistingEnvVar, expectedVariableValue, Default)) + assert.Equal(expectedVariableValue, target.GetOptionalEnvVar(nonExistingEnvVar, expectedVariableValue, NoSanitization)) // non existent variables return an error - val, err := target.GetRecordedVariable(nonExistingEnvVar, Default) + val, err := target.GetEnvVar(nonExistingEnvVar, NoSanitization) // mark test as succeeded assert.Equal(envNotExistsError(nonExistingEnvVar), err.Error()) // now create the env variable and check that it can be fetched os.Setenv(nonExistingEnvVar, expectedVariableValue) defer os.Unsetenv(nonExistingEnvVar) - val, err = target.GetRecordedVariable(nonExistingEnvVar, Default) + val, err = target.GetEnvVar(nonExistingEnvVar, NoSanitization) assert.Equal(expectedVariableValue, val) err = target.Stop() @@ -107,10 +107,10 @@ func (s *recordingTests) TestRecordedVariablesSanitized() { assert.Nil(err) // call GetOptionalRecordedVariable with the Secret_String VariableType arg - assert.Equal(secret, target.GetOptionalRecordedVariable(SanitizedStringVar, secret, Secret_String)) + assert.Equal(secret, target.GetOptionalEnvVar(SanitizedStringVar, secret, Secret_String)) // call GetOptionalRecordedVariable with the Secret_Base64String VariableType arg - assert.Equal(secretBase64, target.GetOptionalRecordedVariable(SanitizedBase64StrigVar, secretBase64, Secret_Base64String)) + assert.Equal(secretBase64, target.GetOptionalEnvVar(SanitizedBase64StrigVar, secretBase64, Secret_Base64String)) // Calling Stop will save the variables and apply the sanitization options err = target.Stop() @@ -143,7 +143,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables( target, err := NewRecording(context, Playback) assert.Nil(err) - target.GetOptionalRecordedVariable(expectedVariableName, expectedVariableValue, Default) + target.GetOptionalEnvVar(expectedVariableName, expectedVariableValue, NoSanitization) err = target.Stop() assert.Nil(err) @@ -159,7 +159,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables( assert.Nil(err) // add a new variable to the existing batch - target2.GetOptionalRecordedVariable(addedVariableName, addedVariableValue, Default) + target2.GetOptionalEnvVar(addedVariableName, addedVariableValue, NoSanitization) err = target2.Stop() assert.Nil(err) diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index c9405d03593a..0e4fa6be1ec6 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -16,8 +16,10 @@ import ( ) type RequestMatcher struct { - context TestContext - IgnoredHeaders map[string]*string + context TestContext + // IgnoredHeaders is a map acting as a hash set of the header names that will be ignored for matching. + // Modifying the keys in the map will affect how headers are matched for recordings. + IgnoredHeaders map[string]struct{} bodyMatcher StringMatcher urlMatcher StringMatcher methodMatcher StringMatcher @@ -26,15 +28,15 @@ type RequestMatcher struct { type StringMatcher func(reqVal string, recVal string) bool type matcherWrapper func(matcher StringMatcher, testContext TestContext) bool -var ignoredHeaders = map[string]*string{ - "Date": nil, - "X-Ms-Date": nil, - "x-ms-date": nil, - "x-ms-client-request-id": nil, - "User-Agent": nil, - "Request-Id": nil, - "traceparent": nil, - "Authorization": nil, +var ignoredHeaders = map[string]struct{}{ + "Date": {}, + "X-Ms-Date": {}, + "x-ms-date": {}, + "x-ms-client-request-id": {}, + "User-Agent": {}, + "Request-Id": {}, + "traceparent": {}, + "Authorization": {}, } const ( @@ -46,25 +48,27 @@ const ( bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" ) -func DefaultMatcher(testContext TestContext) *RequestMatcher { +// defaultMatcher returns a new RequestMatcher configured with the default matching behavior. +func defaultMatcher(testContext TestContext) *RequestMatcher { // The default sanitizer sanitizes the Authorization header matcher := &RequestMatcher{ context: testContext, IgnoredHeaders: ignoredHeaders, } matcher.SetBodyMatcher(func(req string, rec string) bool { - return DefaultStringMatcher(req, rec) + return defaultStringMatcher(req, rec) }) matcher.SetURLMatcher(func(req string, rec string) bool { - return DefaultStringMatcher(req, rec) + return defaultStringMatcher(req, rec) }) matcher.SetMethodMatcher(func(req string, rec string) bool { - return DefaultStringMatcher(req, rec) + return defaultStringMatcher(req, rec) }) return matcher } +// SetBodyMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request body payload with the string value of the recorded body payload. func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) { m.bodyMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) @@ -75,6 +79,7 @@ func (m *RequestMatcher) SetBodyMatcher(matcher StringMatcher) { } } +// SetURLMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request URL with the string value of the recorded URL func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) { m.urlMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) @@ -85,6 +90,7 @@ func (m *RequestMatcher) SetURLMatcher(matcher StringMatcher) { } } +// SetMethodMatcher replaces the default matching behavior with a custom StringMatcher that compares the string value of the request method with the string value of the recorded method func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { m.methodMatcher = func(reqVal string, recVal string) bool { isMatch := matcher(reqVal, recVal) @@ -95,7 +101,7 @@ func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { } } -func DefaultStringMatcher(s1 string, s2 string) bool { +func defaultStringMatcher(s1 string, s2 string) bool { return s1 == s2 } diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/testframework/request_matcher_test.go index f85e7110ee22..1dd07f5c006a 100644 --- a/sdk/internal/testframework/request_matcher_test.go +++ b/sdk/internal/testframework/request_matcher_test.go @@ -33,7 +33,7 @@ const unMatchedBody string = "This body does not match." func (s *requestMatcherTests) TestCompareBodies() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) req := http.Request{Body: closerFromString(matchedBody)} recReq := cassette.Request{Body: matchedBody} @@ -53,7 +53,7 @@ func (s *requestMatcherTests) TestCompareBodies() { func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -73,7 +73,7 @@ func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -93,7 +93,7 @@ func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -117,7 +117,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -141,7 +141,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -171,7 +171,7 @@ func (s *requestMatcherTests) TestCompareURLs() { host := "foo.bar" req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}} recReq := cassette.Request{URL: scheme + "://" + host} - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) assert.True(matcher.compareURLs(&req, recReq.URL)) @@ -187,7 +187,7 @@ func (s *requestMatcherTests) TestCompareMethods() { methodPatch := "PATCH" req := http.Request{Method: methodPost} recReq := cassette.Request{Method: methodPost} - matcher := DefaultMatcher(context) + matcher := defaultMatcher(context) assert.True(matcher.compareMethods(&req, recReq.Method)) From 5549fb597cf457d3a8d9b35a599114f22dda2031 Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Tue, 15 Jun 2021 11:29:48 -0500 Subject: [PATCH 08/11] fb --- sdk/internal/testframework/recording_sanitizer.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/recording_sanitizer.go index 16d728688d83..fefb1b4f6a81 100644 --- a/sdk/internal/testframework/recording_sanitizer.go +++ b/sdk/internal/testframework/recording_sanitizer.go @@ -14,7 +14,7 @@ import ( type RecordingSanitizer struct { recorder *recorder.Recorder - headersToSanitize map[string]*string + headersToSanitize []string urlSanitizer StringSanitizer bodySanitizer StringSanitizer } @@ -30,7 +30,7 @@ var sanitizedValueSlice = []string{SanitizedValue} // To customize sanitization, call AddSanitizedHeaders, AddBodySanitizer, or AddUrlSanitizer. func defaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { // The default sanitizer sanitizes the Authorization header - s := &RecordingSanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} + s := &RecordingSanitizer{headersToSanitize: []string{"Authorization"}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} recorder.AddSaveFilter(s.applySaveFilter) return s @@ -39,7 +39,7 @@ func defaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { // AddSanitizedHeaders adds the supplied header names to the list of headers to be sanitized on request and response recordings. func (s *RecordingSanitizer) AddSanitizedHeaders(headers ...string) { for _, headerName := range headers { - s.headersToSanitize[headerName] = nil + s.headersToSanitize = append(s.headersToSanitize, headerName) } } @@ -54,7 +54,7 @@ func (s *RecordingSanitizer) AddUrlSanitizer(sanitizer StringSanitizer) { } func (s *RecordingSanitizer) sanitizeHeaders(header http.Header) { - for headerName := range s.headersToSanitize { + for _, headerName := range s.headersToSanitize { if _, ok := header[headerName]; ok { header[headerName] = sanitizedValueSlice } From e033ef643a142ca3de504c6d03e3de4c2815441a Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Tue, 15 Jun 2021 11:37:16 -0500 Subject: [PATCH 09/11] fb --- sdk/internal/testframework/recording_sanitizer.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/recording_sanitizer.go index fefb1b4f6a81..2f14bde25298 100644 --- a/sdk/internal/testframework/recording_sanitizer.go +++ b/sdk/internal/testframework/recording_sanitizer.go @@ -19,9 +19,13 @@ type RecordingSanitizer struct { bodySanitizer StringSanitizer } +// StringSanitizer is a func that will modify the string pointed to by the parameter into a sanitized value. type StringSanitizer func(*string) +// SanitizedValue is the default placeholder value to be used for sanitized strings. const SanitizedValue string = "sanitized" + +// SanitizedBase64Value is the default placeholder value to be used for sanitized base-64 encoded strings. const SanitizedBase64Value string = "Kg==" var sanitizedValueSlice = []string{SanitizedValue} From e431fc19e953738b02baa253c0125acb94f6179e Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Tue, 22 Jun 2021 12:19:54 -0500 Subject: [PATCH 10/11] rename testframework to recording --- sdk/internal/testframework/recording.go | 4 ++-- sdk/internal/testframework/recording_test.go | 2 +- sdk/internal/testframework/request_matcher.go | 2 +- .../testframework/request_matcher_test.go | 2 +- .../{recording_sanitizer.go => sanitizer.go} | 22 +++++++++---------- ...ng_sanitizer_test.go => sanitizer_test.go} | 14 ++++++------ sdk/internal/testframework/testcontext.go | 2 +- 7 files changed, 24 insertions(+), 24 deletions(-) rename sdk/internal/testframework/{recording_sanitizer.go => sanitizer.go} (73%) rename sdk/internal/testframework/{recording_sanitizer_test.go => sanitizer_test.go} (91%) diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/testframework/recording.go index 9a74549a1568..c9a4110ac6dc 100644 --- a/sdk/internal/testframework/recording.go +++ b/sdk/internal/testframework/recording.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording import ( "errors" @@ -33,7 +33,7 @@ type Recording struct { recorder *recorder.Recorder src rand.Source now *time.Time - Sanitizer *RecordingSanitizer + Sanitizer *Sanitizer Matcher *RequestMatcher c TestContext } diff --git a/sdk/internal/testframework/recording_test.go b/sdk/internal/testframework/recording_test.go index 15e4ea2e4b04..41c52e141560 100644 --- a/sdk/internal/testframework/recording_test.go +++ b/sdk/internal/testframework/recording_test.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording import ( "fmt" diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/testframework/request_matcher.go index 0e4fa6be1ec6..fafff7eee83c 100644 --- a/sdk/internal/testframework/request_matcher.go +++ b/sdk/internal/testframework/request_matcher.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording import ( "bytes" diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/testframework/request_matcher_test.go index 1dd07f5c006a..d9027c2fce59 100644 --- a/sdk/internal/testframework/request_matcher_test.go +++ b/sdk/internal/testframework/request_matcher_test.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording import ( "io" diff --git a/sdk/internal/testframework/recording_sanitizer.go b/sdk/internal/testframework/sanitizer.go similarity index 73% rename from sdk/internal/testframework/recording_sanitizer.go rename to sdk/internal/testframework/sanitizer.go index 2f14bde25298..c53c90edf899 100644 --- a/sdk/internal/testframework/recording_sanitizer.go +++ b/sdk/internal/testframework/sanitizer.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording import ( "net/http" @@ -12,7 +12,7 @@ import ( "github.com/dnaeon/go-vcr/recorder" ) -type RecordingSanitizer struct { +type Sanitizer struct { recorder *recorder.Recorder headersToSanitize []string urlSanitizer StringSanitizer @@ -32,32 +32,32 @@ var sanitizedValueSlice = []string{SanitizedValue} // defaultSanitizer returns a new RecordingSanitizer with the default sanitizing behavior. // To customize sanitization, call AddSanitizedHeaders, AddBodySanitizer, or AddUrlSanitizer. -func defaultSanitizer(recorder *recorder.Recorder) *RecordingSanitizer { +func defaultSanitizer(recorder *recorder.Recorder) *Sanitizer { // The default sanitizer sanitizes the Authorization header - s := &RecordingSanitizer{headersToSanitize: []string{"Authorization"}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} + s := &Sanitizer{headersToSanitize: []string{"Authorization"}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} recorder.AddSaveFilter(s.applySaveFilter) return s } // AddSanitizedHeaders adds the supplied header names to the list of headers to be sanitized on request and response recordings. -func (s *RecordingSanitizer) AddSanitizedHeaders(headers ...string) { +func (s *Sanitizer) AddSanitizedHeaders(headers ...string) { for _, headerName := range headers { s.headersToSanitize = append(s.headersToSanitize, headerName) } } // AddBodysanitizer configures the supplied StringSanitizer to sanitize recording request and response bodies -func (s *RecordingSanitizer) AddBodysanitizer(sanitizer StringSanitizer) { +func (s *Sanitizer) AddBodysanitizer(sanitizer StringSanitizer) { s.bodySanitizer = sanitizer } // AddUriSanitizer configures the supplied StringSanitizer to sanitize recording request and response URLs -func (s *RecordingSanitizer) AddUrlSanitizer(sanitizer StringSanitizer) { +func (s *Sanitizer) AddUrlSanitizer(sanitizer StringSanitizer) { s.urlSanitizer = sanitizer } -func (s *RecordingSanitizer) sanitizeHeaders(header http.Header) { +func (s *Sanitizer) sanitizeHeaders(header http.Header) { for _, headerName := range s.headersToSanitize { if _, ok := header[headerName]; ok { header[headerName] = sanitizedValueSlice @@ -65,15 +65,15 @@ func (s *RecordingSanitizer) sanitizeHeaders(header http.Header) { } } -func (s *RecordingSanitizer) sanitizeBodies(body *string) { +func (s *Sanitizer) sanitizeBodies(body *string) { s.bodySanitizer(body) } -func (s *RecordingSanitizer) sanitizeURL(url *string) { +func (s *Sanitizer) sanitizeURL(url *string) { s.urlSanitizer(url) } -func (s *RecordingSanitizer) applySaveFilter(i *cassette.Interaction) error { +func (s *Sanitizer) applySaveFilter(i *cassette.Interaction) error { s.sanitizeHeaders(i.Request.Headers) s.sanitizeHeaders(i.Response.Headers) s.sanitizeURL(&i.Request.URL) diff --git a/sdk/internal/testframework/recording_sanitizer_test.go b/sdk/internal/testframework/sanitizer_test.go similarity index 91% rename from sdk/internal/testframework/recording_sanitizer_test.go rename to sdk/internal/testframework/sanitizer_test.go index 23adad100ea4..8b57be2027ef 100644 --- a/sdk/internal/testframework/recording_sanitizer_test.go +++ b/sdk/internal/testframework/sanitizer_test.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording import ( "net/http" @@ -18,7 +18,7 @@ import ( "github.com/stretchr/testify/suite" ) -type recordingSanitizerTests struct { +type sanitizerTests struct { suite.Suite } @@ -28,10 +28,10 @@ const customHeader2 string = "Barheader" const nonSanitizedHeader string = "notsanitized" func TestRecordingSanitizer(t *testing.T) { - suite.Run(t, new(recordingSanitizerTests)) + suite.Run(t, new(sanitizerTests)) } -func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { +func (s *sanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { assert := assert.New(s.T()) server, cleanup := mock.NewServer() server.SetResponse() @@ -57,7 +57,7 @@ func (s *recordingSanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { } } -func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() { +func (s *sanitizerTests) TestAddSanitizedHeadersSanitizes() { assert := assert.New(s.T()) server, cleanup := mock.NewServer() server.SetResponse() @@ -91,7 +91,7 @@ func (s *recordingSanitizerTests) TestAddSanitizedHeadersSanitizes() { } } -func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() { +func (s *sanitizerTests) TestAddUrlSanitizerSanitizes() { assert := assert.New(s.T()) secret := "secretvalue" secretBody := "some body content that contains a " + secret @@ -129,7 +129,7 @@ func (s *recordingSanitizerTests) TestAddUrlSanitizerSanitizes() { } } -func (s *recordingSanitizerTests) TearDownSuite() { +func (s *sanitizerTests) TearDownSuite() { assert := assert.New(s.T()) // cleanup test files err := os.RemoveAll("testfiles") diff --git a/sdk/internal/testframework/testcontext.go b/sdk/internal/testframework/testcontext.go index 97bcc132e5ef..a119344b07ac 100644 --- a/sdk/internal/testframework/testcontext.go +++ b/sdk/internal/testframework/testcontext.go @@ -3,7 +3,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -package testframework +package recording type TestContext interface { Fail(string) From a23cf6a667af2400c04fe52117df9e1b9ec49f0c Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Tue, 22 Jun 2021 12:48:43 -0500 Subject: [PATCH 11/11] dir rename --- sdk/internal/{testframework => recording}/README.md | 0 sdk/internal/{testframework => recording}/recording.go | 0 sdk/internal/{testframework => recording}/recording_test.go | 0 sdk/internal/{testframework => recording}/request_matcher.go | 0 sdk/internal/{testframework => recording}/request_matcher_test.go | 0 sdk/internal/{testframework => recording}/sanitizer.go | 0 sdk/internal/{testframework => recording}/sanitizer_test.go | 0 sdk/internal/{testframework => recording}/testcontext.go | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename sdk/internal/{testframework => recording}/README.md (100%) rename sdk/internal/{testframework => recording}/recording.go (100%) rename sdk/internal/{testframework => recording}/recording_test.go (100%) rename sdk/internal/{testframework => recording}/request_matcher.go (100%) rename sdk/internal/{testframework => recording}/request_matcher_test.go (100%) rename sdk/internal/{testframework => recording}/sanitizer.go (100%) rename sdk/internal/{testframework => recording}/sanitizer_test.go (100%) rename sdk/internal/{testframework => recording}/testcontext.go (100%) diff --git a/sdk/internal/testframework/README.md b/sdk/internal/recording/README.md similarity index 100% rename from sdk/internal/testframework/README.md rename to sdk/internal/recording/README.md diff --git a/sdk/internal/testframework/recording.go b/sdk/internal/recording/recording.go similarity index 100% rename from sdk/internal/testframework/recording.go rename to sdk/internal/recording/recording.go diff --git a/sdk/internal/testframework/recording_test.go b/sdk/internal/recording/recording_test.go similarity index 100% rename from sdk/internal/testframework/recording_test.go rename to sdk/internal/recording/recording_test.go diff --git a/sdk/internal/testframework/request_matcher.go b/sdk/internal/recording/request_matcher.go similarity index 100% rename from sdk/internal/testframework/request_matcher.go rename to sdk/internal/recording/request_matcher.go diff --git a/sdk/internal/testframework/request_matcher_test.go b/sdk/internal/recording/request_matcher_test.go similarity index 100% rename from sdk/internal/testframework/request_matcher_test.go rename to sdk/internal/recording/request_matcher_test.go diff --git a/sdk/internal/testframework/sanitizer.go b/sdk/internal/recording/sanitizer.go similarity index 100% rename from sdk/internal/testframework/sanitizer.go rename to sdk/internal/recording/sanitizer.go diff --git a/sdk/internal/testframework/sanitizer_test.go b/sdk/internal/recording/sanitizer_test.go similarity index 100% rename from sdk/internal/testframework/sanitizer_test.go rename to sdk/internal/recording/sanitizer_test.go diff --git a/sdk/internal/testframework/testcontext.go b/sdk/internal/recording/testcontext.go similarity index 100% rename from sdk/internal/testframework/testcontext.go rename to sdk/internal/recording/testcontext.go