Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to advertise updated OSType while registering ECS Windows instances #2859

Merged
merged 7 commits into from
May 6, 2021
3 changes: 2 additions & 1 deletion agent/api/ecsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,11 @@ func validateRegisteredAttributes(expectedAttributes, actualAttributes []*ecs.At
}

func (client *APIECSClient) getAdditionalAttributes() []*ecs.Attribute {
osFamily := config.GetOperatingSystemFamily()
attrs := []*ecs.Attribute{
{
Name: aws.String(osTypeAttrName),
Value: aws.String(config.OSType),
Value: aws.String(osFamily),
},
}
// Send cpu arch attribute directly when running on external capacity. When running on EC2, this is not needed
Expand Down
10 changes: 5 additions & 5 deletions agent/api/ecsclient/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func TestRegisterContainerInstance(t *testing.T) {

fakeCapabilities := []string{"capability1", "capability2"}
expectedAttributes := map[string]string{
"ecs.os-type": config.OSType,
"ecs.os-type": config.GetOperatingSystemFamily(),
"my_custom_attribute": "Custom_Value1",
"my_other_custom_attribute": "Custom_Value2",
"ecs.availability-zone": "us-west-2b",
Expand Down Expand Up @@ -480,7 +480,7 @@ func TestReRegisterContainerInstance(t *testing.T) {

fakeCapabilities := []string{"capability1", "capability2"}
expectedAttributes := map[string]string{
"ecs.os-type": config.OSType,
"ecs.os-type": config.GetOperatingSystemFamily(),
"ecs.availability-zone": "us-west-2b",
"ecs.outpost-arn": "test:arn:outpost",
}
Expand Down Expand Up @@ -570,7 +570,7 @@ func TestRegisterContainerInstanceWithEmptyTags(t *testing.T) {
client, mc, _ := NewMockClient(mockCtrl, mockEC2Metadata, nil)

expectedAttributes := map[string]string{
"ecs.os-type": config.OSType,
"ecs.os-type": config.GetOperatingSystemFamily(),
"my_custom_attribute": "Custom_Value1",
"my_other_custom_attribute": "Custom_Value2",
}
Expand Down Expand Up @@ -637,7 +637,7 @@ func TestRegisterBlankCluster(t *testing.T) {
client.(*APIECSClient).SetSDK(mc)

expectedAttributes := map[string]string{
"ecs.os-type": config.OSType,
"ecs.os-type": config.GetOperatingSystemFamily(),
}
defaultCluster := config.DefaultClusterName
gomock.InOrder(
Expand Down Expand Up @@ -693,7 +693,7 @@ func TestRegisterBlankClusterNotCreatingClusterWhenErrorNotClusterNotFound(t *te
client.(*APIECSClient).SetSDK(mc)

expectedAttributes := map[string]string{
"ecs.os-type": config.OSType,
"ecs.os-type": config.GetOperatingSystemFamily(),
}

gomock.InOrder(
Expand Down
20 changes: 19 additions & 1 deletion agent/config/const_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,23 @@

package config

// OSType is the type of operating system where agent is running
import "golang.org/x/sys/windows/registry"

const OSType = "windows"
const (
// envSkipDomainJoinCheck is an environment setting that can be used to skip
// domain join check validation. This is useful for integration and
// functional-tests but should not be set for any non-test use-case.
envSkipDomainJoinCheck = "ZZZ_SKIP_DOMAIN_JOIN_CHECK_NOT_SUPPORTED_IN_PRODUCTION"
releaseId2004SAC = "2004"
releaseId1909SAC = "1909"
windowsServer2019 = "Windows Server 2019"
windowsServer2016 = "Windows Server 2016"
windowsServerDataCenter = "Windows Server Datacenter"
installationTypeCore = "Server Core"
installationTypeFull = "Server"
unsupportedWindowsOS = "windows"
osTypeFormat = "WINDOWS_SERVER_%s_%s"
ecsWinRegistryRootKey = registry.LOCAL_MACHINE
ecsWinRegistryRootPath = `SOFTWARE\Microsoft\Windows NT\CurrentVersion`
)
5 changes: 5 additions & 0 deletions agent/config/parse_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ func parseGMSACapability() bool {
func parseFSxWindowsFileServerCapability() bool {
return false
}

// GetOperatingSystemFamily() returns "linux" as operating system family for linux based ecs instances
func GetOperatingSystemFamily() string {
return OSType
}
5 changes: 5 additions & 0 deletions agent/config/parse_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,8 @@ func parseGMSACapability() bool {
func parseFSxWindowsFileServerCapability() bool {
return false
}

// GetOperatingSystemFamily() returns "unsupported" as operating system family for non windows and non linux based ecs instances
func GetOperatingSystemFamily() string {
return OSType
}
107 changes: 98 additions & 9 deletions agent/config/parse_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,111 @@
package config

import (
"fmt"
"os"
"strings"
"syscall"
"unsafe"

"github.com/pkg/errors"

"golang.org/x/sys/windows/registry"

"github.com/aws/amazon-ecs-agent/agent/statemanager/dependencies"
"github.com/aws/amazon-ecs-agent/agent/utils"
"github.com/cihub/seelog"
)

const (
// envSkipDomainJoinCheck is an environment setting that can be used to skip
// domain join check validation. This is useful for integration and
// functional-tests but should not be set for any non-test use-case.
envSkipDomainJoinCheck = "ZZZ_SKIP_DOMAIN_JOIN_CHECK_NOT_SUPPORTED_IN_PRODUCTION"
)
var winRegistry dependencies.WindowsRegistry

func init() {
winRegistry = dependencies.StdRegistry{}
}

func setWinRegistry(mockRegistry dependencies.WindowsRegistry) {
winRegistry = mockRegistry
}

func getInstallationType(installationType string) (string, error) {
switch installationType {
case installationTypeFull:
return "FULL", nil
case installationTypeCore:
return "CORE", nil
default:
return "", errors.Errorf("unsupported Installation type:%s", installationType)
}
}

func getReleaseIdForSACReleases(productName string) (string, error) {
if strings.HasPrefix(productName, windowsServer2019) {
return "2019", nil
} else if strings.HasPrefix(productName, windowsServer2016) {
return "2016", nil
}
err := seelog.Errorf("unsupported productName:%s for Windows SAC Release", productName)
return "", err
}

func getReleaseIdForLTSCReleases(releaseId string) (string, error) {
switch releaseId {
case releaseId2004SAC:
return releaseId2004SAC, nil
case releaseId1909SAC:
return releaseId1909SAC, nil
default:
return "", errors.Errorf("unsupported ReleaseId:%s for Windows LTSC Release", releaseId)
}
}

// GetOperatingSystemFamily() reads the NT current version from windows registry and constructs operating system family string
// In case of any exception this method just returns "windows" as operating system type.
func GetOperatingSystemFamily() string {
key, err := winRegistry.OpenKey(ecsWinRegistryRootKey, ecsWinRegistryRootPath, registry.QUERY_VALUE)
if err != nil {
seelog.Errorf("Unable to open Windows registry key to determine Windows version: %v", err)
return unsupportedWindowsOS
}
defer key.Close()

productName, _, err := key.GetStringValue("ProductName")
if err != nil {
seelog.Errorf("Unable to read registry key, ProductName: %v", err)
return unsupportedWindowsOS
}
installationType, _, err := key.GetStringValue("InstallationType")
if err != nil {
seelog.Errorf("Unable to read registry key, InstallationType: %v", err)
return unsupportedWindowsOS
}
iType, err := getInstallationType(installationType)
if err != nil {
seelog.Errorf("Invalid Installation type found: %v", err)
return unsupportedWindowsOS
}

releaseId := ""
if strings.HasPrefix(productName, windowsServerDataCenter) {
releaseIdFromRegistry, _, err := key.GetStringValue("ReleaseId")
if err != nil {
seelog.Errorf("Unable to read registry key, ReleaseId: %v", err)
return unsupportedWindowsOS
}

releaseId, err = getReleaseIdForLTSCReleases(releaseIdFromRegistry)
if err != nil {
seelog.Errorf("Failed to construct releaseId for Windows LTSC, Error: %v", err)
return unsupportedWindowsOS
}
} else {
releaseId, err = getReleaseIdForSACReleases(productName)
if err != nil {
seelog.Errorf("Failed to construct releaseId for Windows SAC, Error: %v", err)
return unsupportedWindowsOS
}
}
return fmt.Sprintf(osTypeFormat, releaseId, iType)
}

// parseGMSACapability is used to determine if gMSA support can be enabled
func parseGMSACapability() bool {
Expand Down Expand Up @@ -94,16 +182,17 @@ func isDomainJoined() (bool, error) {
// isWindows2016 is used to check if container instance is versioned Windows 2016
// Reference: https://godoc.org/golang.org/x/sys/windows/registry
var isWindows2016 = func() (bool, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
key, err := winRegistry.OpenKey(ecsWinRegistryRootKey, ecsWinRegistryRootPath, registry.QUERY_VALUE)

if err != nil {
seelog.Errorf("Unable to open Windows registry key to determine Windows version: %v", err)
seelog.Errorf("unable to open Windows registry key to determine Windows version: %v", err)
return false, err
}
defer key.Close()

version, _, err := key.GetStringValue("ProductName")
if err != nil {
seelog.Errorf("Unable to read current version from Windows registry: %v", err)
seelog.Errorf("unable to read current version from Windows registry: %v", err)
return false, err
}

Expand Down
118 changes: 117 additions & 1 deletion agent/config/parse_windows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
package config

import (
"github.com/stretchr/testify/assert"
"golang.org/x/sys/windows/registry"

"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/golang/mock/gomock"

mock_dependencies "github.com/aws/amazon-ecs-agent/agent/statemanager/dependencies/mocks"
)

func TestParseGMSACapability(t *testing.T) {
Expand Down Expand Up @@ -50,3 +55,114 @@ func TestParseFSxWindowsFileServerCapability(t *testing.T) {

assert.False(t, parseFSxWindowsFileServerCapability())
}

func getMockRegistry(controller *gomock.Controller) *mock_dependencies.MockWindowsRegistry {
winRegistry := mock_dependencies.NewMockWindowsRegistry(controller)
setWinRegistry(winRegistry)
return winRegistry
}

func getMockKey(t *testing.T) *mock_dependencies.MockRegistryKey {
ctrl := gomock.NewController(t)
winRegistry := getMockRegistry(ctrl)
mockKey := mock_dependencies.NewMockRegistryKey(ctrl)
winRegistry.EXPECT().OpenKey(ecsWinRegistryRootKey, ecsWinRegistryRootPath, gomock.Any()).Return(mockKey, nil)
return mockKey
}

func TestGetOperatingSystemFamilyForWS2019Core(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server 2019 Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("ReleaseId").Return(`1809`, uint32(0), nil)
mockKey.EXPECT().Close()
assert.Equal(t, "WINDOWS_SERVER_2019_CORE", GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForWS2019Full(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server 2019 Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("ReleaseId").Return(`1809`, uint32(0), nil)
mockKey.EXPECT().Close()

assert.Equal(t, "WINDOWS_SERVER_2019_FULL", GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForWS2016Full(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server 2016 Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("ReleaseId").Return(`1607`, uint32(0), nil)
mockKey.EXPECT().Close()

assert.Equal(t, "WINDOWS_SERVER_2016_FULL", GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForWS2004Core(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("ReleaseId").Return(`2004`, uint32(0), nil)
mockKey.EXPECT().Close()

assert.Equal(t, "WINDOWS_SERVER_2004_CORE", GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForKeyError(t *testing.T) {
mockKey := getMockKey(t)
winRegistry := getMockRegistry(gomock.NewController(t))
winRegistry.EXPECT().OpenKey(ecsWinRegistryRootKey, ecsWinRegistryRootPath, gomock.Any()).Return(mockKey, registry.ErrNotExist)
mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForProductNameNotExistError(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return("", uint32(0), registry.ErrNotExist)

mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForInstallationTypeNotExistError(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core`, uint32(0), registry.ErrNotExist)
mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForInvalidInstallationType(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core Invalid`, uint32(0), nil)
mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForReleaseIdNotExistError(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("ReleaseId").Return(`2004`, uint32(0), registry.ErrNotExist)
mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForInvalidLTSCReleaseId(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server Datacenter`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("ReleaseId").Return(`2028`, uint32(0), registry.ErrNotExist)
mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}

func TestGetOperatingSystemFamilyForInvalidSACReleaseId(t *testing.T) {
mockKey := getMockKey(t)
mockKey.EXPECT().GetStringValue("ProductName").Return(`Windows Server BadVersion`, uint32(0), nil)
mockKey.EXPECT().GetStringValue("InstallationType").Return(`Server Core`, uint32(0), nil)
mockKey.EXPECT().Close()
assert.Equal(t, unsupportedWindowsOS, GetOperatingSystemFamily())
}