Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 33 additions & 45 deletions sdk/resource_locator.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,59 +271,47 @@ func (rl ResourceLocator) getLength() uint16 {
return uint16(1 /* protocol byte */ + 1 /* length byte */ + len(rl.body) + len(rl.identifier))
}

// setURL - Store a fully qualified protocol+body string into a ResourceLocator as a protocol value and a body string
// setURLWithIdentifier - Store a fully qualified protocol+body string and an identifier into a ResourceLocator.
func (rl *ResourceLocator) setURLWithIdentifier(url string, identifier string) error {
if identifier == "" {
return errors.New("identifier is empty")
}
lowerURL := strings.ToLower(url)

if strings.HasPrefix(lowerURL, kPrefixHTTPS) {
urlBody := url[len(kPrefixHTTPS):]
if len(urlBody) > kMaxBodyLen {
return errors.New("URL too long")
}
identifierLen := len(identifier)
switch {
case identifierLen == 0:
rl.protocol = urlProtocolHTTPS | identifierNone
case identifierLen >= 1 && identifierLen <= 2:
rl.protocol = urlProtocolHTTPS | identifier2Byte
case identifierLen >= 3 && identifierLen <= 8:
rl.protocol = urlProtocolHTTPS | identifier8Byte
case identifierLen >= 9 && identifierLen <= 32:
rl.protocol = urlProtocolHTTPS | identifier32Byte
default:
return fmt.Errorf("unsupported identifier length: %d", identifierLen)
}
rl.body = urlBody
rl.identifier = identifier
return nil
return rl.setURLParts(url[len(kPrefixHTTPS):], identifier, urlProtocolHTTPS)
}
if strings.HasPrefix(lowerURL, kPrefixHTTP) {
urlBody := url[len(kPrefixHTTP):]
if len(urlBody) > kMaxBodyLen {
return errors.New("URL too long")
}
identifierLen := len(identifier)
padding := ""
switch {
case identifierLen == 0:
rl.protocol = urlProtocolHTTP | identifierNone
case identifierLen >= 1 && identifierLen <= identifier2ByteLength:
padding = strings.Repeat("\x00", identifier2ByteLength-identifierLen)
rl.protocol = urlProtocolHTTP | identifier2Byte
case identifierLen >= 3 && identifierLen <= identifier8ByteLength:
padding = strings.Repeat("\x00", identifier8ByteLength-identifierLen)
rl.protocol = urlProtocolHTTP | identifier8Byte
case identifierLen >= 9 && identifierLen <= identifier32ByteLength:
padding = strings.Repeat("\x00", identifier32ByteLength-identifierLen)
rl.protocol = urlProtocolHTTP | identifier32Byte
default:
return fmt.Errorf("unsupported identifier length: %d", identifierLen)
}
rl.body = urlBody
rl.identifier = identifier + padding
return nil
return rl.setURLParts(url[len(kPrefixHTTP):], identifier, urlProtocolHTTP)
}
return errors.New("unsupported protocol with identifier: " + url)
}

func (rl *ResourceLocator) setURLParts(urlBody, identifier string, baseProtocol protocolHeader) error {
if len(urlBody) > kMaxBodyLen {
return errors.New("URL too long")
}

identifierLen := len(identifier)
var idProtocol protocolHeader
var paddingLen int

switch {
case identifierLen <= identifier2ByteLength:
idProtocol = identifier2Byte
paddingLen = identifier2ByteLength - identifierLen
case identifierLen > identifier2ByteLength && identifierLen <= identifier8ByteLength:
idProtocol = identifier8Byte
paddingLen = identifier8ByteLength - identifierLen
case identifierLen > identifier8ByteLength && identifierLen <= identifier32ByteLength:
idProtocol = identifier32Byte
paddingLen = identifier32ByteLength - identifierLen
default:
return fmt.Errorf("unsupported identifier length: %d", identifierLen)
}

rl.protocol = baseProtocol | idProtocol
rl.body = urlBody
rl.identifier = identifier + strings.Repeat("\x00", paddingLen)
return nil
}
105 changes: 105 additions & 0 deletions sdk/resource_locator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,111 @@ func TestReadResourceLocator(t *testing.T) {
}
}

func TestURLWithIdentifier(t *testing.T) {
tests := []struct {
name string
url string
identifier string
expectedErr bool
expectedProtocol protocolHeader
expectedBody string
expectedIdentifier string
}{
{
name: "HTTPS URL with 18-byte identifier",
url: "https://example.com",
identifier: "aws-kms-asymmetric",
expectedErr: false,
expectedProtocol: urlProtocolHTTPS | identifier32Byte,
expectedBody: "example.com",
expectedIdentifier: "aws-kms-asymmetric\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
},
{
name: "HTTP URL with 8-byte identifier",
url: "http://example.com",
identifier: "id123456",
expectedErr: false,
expectedProtocol: urlProtocolHTTP | identifier8Byte,
expectedBody: "example.com",
expectedIdentifier: "id123456",
},
{
name: "HTTPS URL with 2-byte identifier",
url: "https://example.com",
identifier: "i1",
expectedErr: false,
expectedProtocol: urlProtocolHTTPS | identifier2Byte,
expectedBody: "example.com",
expectedIdentifier: "i1",
},
{
name: "HTTP URL with 6-byte identifier",
url: "http://example.com",
identifier: "id1234",
expectedErr: false,
expectedProtocol: urlProtocolHTTP | identifier8Byte,
expectedBody: "example.com",
expectedIdentifier: "id1234\x00\x00",
},
{
name: "HTTPS URL with 32-byte identifier",
url: "https://long.url.for.testing.com/path",
identifier: "12345678901234567890123456789012",
expectedErr: false,
expectedProtocol: urlProtocolHTTPS | identifier32Byte,
expectedBody: "long.url.for.testing.com/path",
expectedIdentifier: "12345678901234567890123456789012",
},
{
name: "Unsupported protocol should error",
url: "ftp://example.com",
identifier: "id1",
expectedErr: true,
expectedProtocol: 0,
expectedBody: "",
expectedIdentifier: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rl := &ResourceLocator{}
err := rl.setURLWithIdentifier(tt.url, tt.identifier)

if tt.expectedErr {
if err == nil {
t.Fatal("expected an error but got none")
}
return
}
if err != nil {
t.Fatalf("setURLWithIdentifier() unexpected error: %v", err)
}

var buf bytes.Buffer
err = rl.writeResourceLocator(&buf)
if err != nil {
t.Fatalf("writeResourceLocator() unexpected error: %v", err)
}

parsedRl, err := NewResourceLocatorFromReader(&buf)
if err != nil {
t.Fatalf("NewResourceLocatorFromReader() unexpected error: %v", err)
}

if tt.expectedProtocol != parsedRl.protocol {
t.Fatalf("expected protocol %v, got %v", tt.expectedProtocol, parsedRl.protocol)
}
if tt.expectedBody != parsedRl.body {
t.Fatalf("expected body %q, got %q", tt.expectedBody, parsedRl.body)
}
if tt.expectedIdentifier != parsedRl.identifier {
t.Fatalf("expected identifier %q, got %q", tt.expectedIdentifier, parsedRl.identifier)
}
})
}
}

func TestGetIdentifier(t *testing.T) {
tests := []struct {
n string
Expand Down
Loading