Skip to content
20 changes: 10 additions & 10 deletions sdk/tables/aztable/shared_policy_shared_key_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,15 @@ func (c *SharedKeyCredential) SetAccountKey(accountKey string) error {
}

// computeHMACSHA256 generates a hash signature for an HTTP request or for a SAS.
func (c *SharedKeyCredential) ComputeHMACSHA256(message string) (base64String string) {
func (c *SharedKeyCredential) ComputeHMACSHA256(message string) (string, error) {
h := hmac.New(sha256.New, c.accountKey.Load().([]byte))
h.Write([]byte(message))
return base64.StdEncoding.EncodeToString(h.Sum(nil))
_, err := h.Write([]byte(message))
return base64.StdEncoding.EncodeToString(h.Sum(nil)), err
}

func (c *SharedKeyCredential) buildStringToSign(req *http.Request) (string, error) {
// https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services
headers := req.Header
contentLength := headers.Get(azcore.HeaderContentLength)
if contentLength == "0" {
contentLength = ""
}

canonicalizedResource, err := c.buildCanonicalizedResource(req.URL)
if err != nil {
Expand All @@ -79,6 +75,7 @@ func (c *SharedKeyCredential) buildStringToSign(req *http.Request) (string, erro
return stringToSign, nil
}

//nolint
func (c *SharedKeyCredential) buildCanonicalizedHeader(headers http.Header) string {
cm := map[string][]string{}
for k, v := range headers {
Expand All @@ -105,7 +102,7 @@ func (c *SharedKeyCredential) buildCanonicalizedHeader(headers http.Header) stri
ch.WriteRune(':')
ch.WriteString(strings.Join(cm[key], ","))
}
return string(ch.Bytes())
return ch.String()
}

func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, error) {
Expand Down Expand Up @@ -133,7 +130,7 @@ func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, er
//do something here
cr.WriteString("?" + "comp=" + compVal[0])
}
return string(cr.Bytes()), nil
return cr.String(), nil
}

// AuthenticationPolicy implements the Credential interface on SharedKeyCredential.
Expand All @@ -147,7 +144,10 @@ func (c *SharedKeyCredential) AuthenticationPolicy(azcore.AuthenticationPolicyOp
if err != nil {
return nil, err
}
signature := c.ComputeHMACSHA256(stringToSign)
signature, err := c.ComputeHMACSHA256(stringToSign)
if err != nil {
return nil, err
}
authHeader := strings.Join([]string{"SharedKeyLite ", c.AccountName(), ":", signature}, "")
req.Request.Header.Set(azcore.HeaderAuthorization, authHeader)

Expand Down
2 changes: 1 addition & 1 deletion sdk/tables/aztable/table_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (t *TableClient) GetEntity(ctx context.Context, partitionKey string, rowKey
if err != nil {
return resp, err
}
castAndRemoveAnnotations(&resp.Value)
err = castAndRemoveAnnotations(&resp.Value)
return resp, err
}

Expand Down
16 changes: 11 additions & 5 deletions sdk/tables/aztable/table_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package aztable
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"net/http"
"testing"
Expand Down Expand Up @@ -180,7 +181,7 @@ func (s *tableClientLiveTests) TestUpsertEntity() {
assert.Equalf(postMerge[mergeProp], val, "%s property should equal %s", mergeProp, val)
}

func (s *tableClientLiveTests) _TestGetEntity() {
func (s *tableClientLiveTests) TestGetEntity() {
assert := assert.New(s.T())
require := require.New(s.T())
client, delete := s.init(true)
Expand Down Expand Up @@ -233,7 +234,8 @@ func (s *tableClientLiveTests) TestQuerySimpleEntity() {
for pager.NextPage(ctx) {
resp = pager.PageResponse()
models = make([]simpleEntity, len(resp.TableEntityQueryResponse.Value))
resp.TableEntityQueryResponse.AsModels(&models)
err := resp.TableEntityQueryResponse.AsModels(&models)
assert.Nil(err)
assert.Equal(len(resp.TableEntityQueryResponse.Value), expectedCount)
}
resp = pager.PageResponse()
Expand Down Expand Up @@ -442,7 +444,8 @@ func (s *tableClientLiveTests) TestBatchError() {
assert.Equal(error_empty_transaction, err.Error())

// Add the last entity to the table prior to adding it as part of the batch to cause a batch failure.
client.AddEntity(ctx, (*entitiesToCreate)[2])
_, err = client.AddEntity(ctx, (*entitiesToCreate)[2])
assert.Nil(err)

// Add the entities to the batch
for i := 0; i < cap(batch); i++ {
Expand Down Expand Up @@ -498,7 +501,10 @@ func (s *tableClientLiveTests) init(doCreate bool) (*TableClient, func()) {
}
}
return client, func() {
client.Delete(ctx)
_, err := client.Delete(ctx)
if err != nil {
fmt.Printf("Error deleting table. %v\n", err.Error())
}
}
}

Expand All @@ -515,7 +521,7 @@ func getStringFromBody(e *runtime.ResponseError) string {
if err != nil {
return "<emtpy body>"
}
b = ioutil.NopCloser(&body)
_ = ioutil.NopCloser(&body)
}
return body.String()
}
8 changes: 4 additions & 4 deletions sdk/tables/aztable/table_pagers.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (p *tableQueryResponsePager) Err() error {

func castAndRemoveAnnotationsSlice(entities *[]map[string]interface{}) {
for _, e := range *entities {
castAndRemoveAnnotations(&e)
castAndRemoveAnnotations(&e) //nolint:errcheck
}
}

Expand Down Expand Up @@ -212,7 +212,7 @@ func castAndRemoveAnnotations(entity *map[string]interface{}) error {
}
value[valueKey] = i
default:
return errors.New(fmt.Sprintf("unsupported annotation found: %s", k))
return fmt.Errorf("unsupported annotation found: %s", k)
}
// remove the annotation key
delete(value, k)
Expand Down Expand Up @@ -249,7 +249,7 @@ func toOdataAnnotatedDictionary(entity *map[string]interface{}) error {
entMap[k] = time.UTC().Format(ISO8601)
continue
default:
return errors.New(fmt.Sprintf("Invalid struct for entity field '%s' of type '%s'", k, tn))
return fmt.Errorf("Invalid struct for entity field '%s' of type '%s'", k, tn)
}
case reflect.Float32, reflect.Float64:
entMap[odataType(k)] = edmDouble
Expand Down Expand Up @@ -321,7 +321,7 @@ func toMap(ent interface{}) (*map[string]interface{}, error) {
entMap[name] = time.UTC().Format(ISO8601)
continue
default:
return nil, errors.New(fmt.Sprintf("Invalid struct for entity field '%s' of type '%s'", typeOfT.Field(i).Name, tn))
return nil, fmt.Errorf("Invalid struct for entity field '%s' of type '%s'", typeOfT.Field(i).Name, tn)
}
case reflect.Float32, reflect.Float64:
entMap[odataType(name)] = edmDouble
Expand Down
38 changes: 26 additions & 12 deletions sdk/tables/aztable/table_pagers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ import (
"github.com/stretchr/testify/assert"
)

type pagerTests struct{}

func TestCastAndRemoveAnnotations(t *testing.T) {
assert := assert.New(t)

Expand Down Expand Up @@ -75,8 +73,10 @@ func BenchmarkUnMarshal_AsJson_CastAndRemove_Map(b *testing.B) {
bt := []byte(complexPayload)
for i := 0; i < b.N; i++ {
var val = make(map[string]interface{})
json.Unmarshal(bt, &val)
castAndRemoveAnnotations(&val)
err := json.Unmarshal(bt, &val)
assert.Nil(err)
err = castAndRemoveAnnotations(&val)
assert.Nil(err)
assert.Equal("somePartition", val["PartitionKey"])
}
}
Expand All @@ -87,28 +87,41 @@ func BenchmarkUnMarshal_FromMap_Entity(b *testing.B) {
bt := []byte(complexPayload)
for i := 0; i < b.N; i++ {
var val = make(map[string]interface{})
json.Unmarshal(bt, &val)
err := json.Unmarshal(bt, &val)
if err != nil {
panic(err)
}
result := complexEntity{}
err := EntityMapAsModel(val, &result)
err = EntityMapAsModel(val, &result)
assert.Nil(err)
assert.Equal("somePartition", result.PartitionKey)
}
}

func check(e error) {
if e != nil {
panic(e)
}
}

func BenchmarkMarshal_Entity_ToMap_ToOdataDict_Map(b *testing.B) {
ent := createComplexEntity()
for i := 0; i < b.N; i++ {
m, _ := toMap(ent)
toOdataAnnotatedDictionary(m)
json.Marshal(m)
err := toOdataAnnotatedDictionary(m)
check(err)
_, err = json.Marshal(m)
check(err)
}
}

func BenchmarkMarshal_Map_ToOdataDict_Map(b *testing.B) {
ent := createComplexEntityMap()
for i := 0; i < b.N; i++ {
toOdataAnnotatedDictionary(&ent)
json.Marshal(ent)
err := toOdataAnnotatedDictionary(&ent)
check(err)
_, err = json.Marshal(ent)
check(err)
}
}

Expand Down Expand Up @@ -180,11 +193,12 @@ func TestDeserializeFromMap(t *testing.T) {
expected := createComplexEntity()
bt := []byte(complexPayload)
var val = make(map[string]interface{})
json.Unmarshal(bt, &val)
err := json.Unmarshal(bt, &val)
assert.Nil(err)
result := complexEntity{}
// tt := reflect.TypeOf(complexEntity{})
// err := fromMap(tt, getTypeValueMap(tt), &val, reflect.ValueOf(&result).Elem())
err := EntityMapAsModel(val, &result)
err = EntityMapAsModel(val, &result)
assert.Nil(err)
assert.EqualValues(expected, result)
}
Expand Down
6 changes: 2 additions & 4 deletions sdk/tables/aztable/table_service_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ func (t *TableServiceClient) Query(queryOptions *QueryOptions) TableQueryRespons
}

func isCosmosEndpoint(url string) bool {
isCosmosEmulator := strings.Index(url, "localhost") >= 0 && strings.Index(url, "8902") >= 0
return isCosmosEmulator ||
strings.Index(url, CosmosTableDomain) >= 0 ||
strings.Index(url, LegacyCosmosTableDomain) >= 0
isCosmosEmulator := strings.Contains(url, "localhost") && strings.Contains(url, "8902")
return isCosmosEmulator || strings.Contains(url, CosmosTableDomain) || strings.Contains(url, LegacyCosmosTableDomain)
}
22 changes: 18 additions & 4 deletions sdk/tables/aztable/table_service_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,17 @@ func (s *tableServiceClientLiveTests) TestServiceErrors() {
assert := assert.New(s.T())
context := getTestContext(s.T().Name())
tableName, err := getTableName(context)
failIfNotNil(assert, err)

_, err = context.client.Create(ctx, tableName)
defer context.client.Delete(ctx, tableName)
assert.Nil(err)
delete := func() {
_, err := context.client.Delete(ctx, tableName)
if err != nil {
fmt.Printf("Error cleaning up test. %v\n", err.Error())
}
}
defer delete()
failIfNotNil(assert, err)

// Create a duplicate table to produce an error
_, err = context.client.Create(ctx, tableName)
Expand All @@ -53,11 +60,18 @@ func (s *tableServiceClientLiveTests) TestCreateTable() {
assert := assert.New(s.T())
context := getTestContext(s.T().Name())
tableName, err := getTableName(context)
failIfNotNil(assert, err)

resp, err := context.client.Create(ctx, tableName)
defer context.client.Delete(ctx, tableName)
delete := func() {
_, err := context.client.Delete(ctx, tableName)
if err != nil {
fmt.Printf("Error cleaning up test. %v\n", err.Error())
}
}
defer delete()

assert.Nil(err)
failIfNotNil(assert, err)
assert.Equal(*resp.TableResponse.TableName, tableName)
}

Expand Down
Loading