Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions pkg/deviceplugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ func newServer(devType string,
}
}

func (srv *server) GetDevicePluginOptions(ctx context.Context, empty *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
func (srv *server) getDevicePluginOptions() *pluginapi.DevicePluginOptions {
return &pluginapi.DevicePluginOptions{
PreStartRequired: srv.preStartContainer != nil,
GetPreferredAllocationAvailable: false,
}, nil
GetPreferredAllocationAvailable: srv.getPreferredAllocation != nil,
}
}

func (srv *server) GetDevicePluginOptions(ctx context.Context, empty *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
return srv.getDevicePluginOptions(), nil
}

func (srv *server) sendDevices(stream pluginapi.DevicePlugin_ListAndWatchServer) error {
Expand Down Expand Up @@ -241,12 +245,8 @@ func (srv *server) setupAndServe(namespace string, devicePluginPath string, kube
return err
}

options := &pluginapi.DevicePluginOptions{
PreStartRequired: srv.preStartContainer != nil,
}

// Register with Kubelet.
err = registerWithKubelet(kubeletSocket, pluginEndpoint, resourceName, options)
err = srv.registerWithKubelet(kubeletSocket, pluginEndpoint, resourceName)
if err != nil {
return err
}
Expand Down Expand Up @@ -293,7 +293,7 @@ func watchFile(file string) error {
}
}

func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string, options *pluginapi.DevicePluginOptions) error {
func (srv *server) registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string) error {
ctx := context.Background()
conn, err := grpc.DialContext(ctx, kubeletSocket, grpc.WithInsecure(),
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
Expand All @@ -308,7 +308,7 @@ func registerWithKubelet(kubeletSocket, pluginEndPoint, resourceName string, opt
Version: pluginapi.Version,
Endpoint: pluginEndPoint,
ResourceName: resourceName,
Options: options,
Options: srv.getDevicePluginOptions(),
}

_, err = client.Register(ctx, reqt)
Expand Down
60 changes: 31 additions & 29 deletions pkg/deviceplugin/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@ func newKubeletStub(socket string) *kubeletStub {
}
}

// newTestServer returns a server with devices for testing purposes.
func newTestServer() *server {
return &server{
devType: "testtype",
devices: map[string]DeviceInfo{
"dev1": {
state: pluginapi.Healthy,
},
"dev2": {
state: pluginapi.Healthy,
},
},
updatesCh: make(chan map[string]DeviceInfo, 1),
}
}

// Minimal implementation of deviceplugin.RegistrationServer interface

func (k *kubeletStub) Register(ctx context.Context, r *pluginapi.RegisterRequest) (*pluginapi.Empty, error) {
Expand Down Expand Up @@ -89,7 +105,9 @@ func (k *kubeletStub) start() error {
func TestRegisterWithKublet(t *testing.T) {
pluginSocket := path.Join(devicePluginPath, pluginEndpoint)

err := registerWithKubelet(kubeletSocket, pluginSocket, resourceName, nil)
srv := newTestServer()

err := srv.registerWithKubelet(kubeletSocket, pluginSocket, resourceName)
if err == nil {
t.Error("No error triggered when kubelet is not accessible")
}
Expand All @@ -101,7 +119,7 @@ func TestRegisterWithKublet(t *testing.T) {
}
defer kubelet.server.Stop()

err = registerWithKubelet(kubeletSocket, pluginSocket, resourceName, nil)
err = srv.registerWithKubelet(kubeletSocket, pluginSocket, resourceName)
if err != nil {
t.Errorf("Can't register device plugin: %+v", err)
}
Expand All @@ -117,18 +135,7 @@ func TestSetupAndServe(t *testing.T) {
}
defer kubelet.server.Stop()

srv := &server{
devType: "testtype",
devices: map[string]DeviceInfo{
"dev1": {
state: pluginapi.Healthy,
},
"dev2": {
state: pluginapi.Healthy,
},
},
updatesCh: make(chan map[string]DeviceInfo),
}
srv := newTestServer()

defer maybeLogError(srv.Stop, "unable to stop server")
go maybeLogError(func() error {
Expand Down Expand Up @@ -220,7 +227,7 @@ func TestSetupAndServe(t *testing.T) {
}

func TestStop(t *testing.T) {
srv := &server{}
srv := newTestServer()
if err := srv.Stop(); err == nil {
t.Error("Calling Stop() before Serve() is successful")
}
Expand All @@ -234,7 +241,7 @@ func TestAllocate(t *testing.T) {
},
},
}
srv := &server{}
srv := newTestServer()

tcases := []struct {
name string
Expand Down Expand Up @@ -456,9 +463,8 @@ func TestListAndWatch(t *testing.T) {

for _, tt := range tcases {
devCh := make(chan map[string]DeviceInfo, len(tt.updates))
testServer := &server{
updatesCh: devCh,
}
testServer := newTestServer()
testServer.updatesCh = devCh

server := &listAndWatchServerStub{
testServer: testServer,
Expand All @@ -483,7 +489,7 @@ func TestListAndWatch(t *testing.T) {
}

func TestGetDevicePluginOptions(t *testing.T) {
srv := &server{}
srv := newTestServer()
if _, err := srv.GetDevicePluginOptions(context.Background(), nil); err != nil {
t.Errorf("unexpected error: %+v", err)
}
Expand All @@ -508,9 +514,8 @@ func TestPreStartContainer(t *testing.T) {
}
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
srv := &server{
preStartContainer: tc.preStartContainer,
}
srv := newTestServer()
srv.preStartContainer = tc.preStartContainer
_, err := srv.PreStartContainer(context.Background(), nil)
if !tc.expectedError && err != nil {
t.Errorf("unexpected error: %v", err)
Expand Down Expand Up @@ -541,9 +546,8 @@ func TestGetPreferredAllocation(t *testing.T) {
}
for _, tc := range tcases {
t.Run(tc.name, func(t *testing.T) {
srv := &server{
getPreferredAllocation: tc.getPreferredAllocation,
}
srv := newTestServer()
srv.getPreferredAllocation = tc.getPreferredAllocation
_, err := srv.GetPreferredAllocation(context.Background(), nil)
if !tc.expectedError && err != nil {
t.Errorf("unexpected error: %v", err)
Expand All @@ -560,9 +564,7 @@ func TestNewServer(t *testing.T) {
}

func TestUpdate(t *testing.T) {
srv := &server{
updatesCh: make(chan map[string]DeviceInfo, 1),
}
srv := newTestServer()
srv.Update(make(map[string]DeviceInfo))
}

Expand Down