Skip to content

Commit 57b9171

Browse files
committed
Use custom gomock.Matcher for protobuf messages
The default gomock.Eq matcher uses reflect.DeepEqual to test for equality, but as of protobuf 1.1.0, this will not work. Instead, proto.Equal() needs to be used. Thus, implement a custom protobuf matcher to test for message equality. Add additional test for ControllerPublishVolume, which was the RPC seen to trigger the protobuf comparison failure.
1 parent 1bf94ed commit 57b9171

File tree

1 file changed

+84
-1
lines changed

1 file changed

+84
-1
lines changed

test/co_test.go

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@ limitations under the License.
1616
package test
1717

1818
import (
19+
"fmt"
20+
"reflect"
1921
"testing"
2022

2123
csi "github.com/container-storage-interface/spec/lib/go/csi/v0"
2224
gomock "github.com/golang/mock/gomock"
25+
"github.com/golang/protobuf/proto"
2326
mock_driver "github.com/kubernetes-csi/csi-test/driver"
2427
mock_utils "github.com/kubernetes-csi/csi-test/utils"
2528
"golang.org/x/net/context"
@@ -58,6 +61,24 @@ func TestPluginInfoResponse(t *testing.T) {
5861
}
5962
}
6063

64+
type pbMatcher struct {
65+
x proto.Message
66+
}
67+
68+
func (p pbMatcher) Matches(x interface{}) bool {
69+
y := x.(proto.Message)
70+
return proto.Equal(p.x, y)
71+
}
72+
73+
func (p pbMatcher) String() string {
74+
return fmt.Sprintf("pb equal to %v", p.x)
75+
}
76+
77+
func pbMatch(x interface{}) gomock.Matcher {
78+
v := x.(proto.Message)
79+
return &pbMatcher{v}
80+
}
81+
6182
func TestGRPCGetPluginInfoReponse(t *testing.T) {
6283

6384
// Setup mock
@@ -79,7 +100,7 @@ func TestGRPCGetPluginInfoReponse(t *testing.T) {
79100

80101
// Setup expectation
81102
// !IMPORTANT!: Must set context expected value to gomock.Any() to match any value
82-
driver.EXPECT().GetPluginInfo(gomock.Any(), in).Return(out, nil).Times(1)
103+
driver.EXPECT().GetPluginInfo(gomock.Any(), pbMatch(in)).Return(out, nil).Times(1)
83104

84105
// Create a new RPC
85106
server := mock_driver.NewMockCSIDriver(&mock_driver.MockCSIDriverServers{
@@ -103,3 +124,65 @@ func TestGRPCGetPluginInfoReponse(t *testing.T) {
103124
t.Errorf("Unknown name: %s\n", name)
104125
}
105126
}
127+
128+
func TestGRPCAttach(t *testing.T) {
129+
130+
// Setup mock
131+
m := gomock.NewController(&mock_utils.SafeGoroutineTester{})
132+
defer m.Finish()
133+
driver := mock_driver.NewMockControllerServer(m)
134+
135+
// Setup input
136+
defaultVolumeID := "myname"
137+
defaultNodeID := "MyNodeID"
138+
defaultCaps := &csi.VolumeCapability{
139+
AccessType: &csi.VolumeCapability_Mount{
140+
Mount: &csi.VolumeCapability_MountVolume{},
141+
},
142+
AccessMode: &csi.VolumeCapability_AccessMode{
143+
Mode: csi.VolumeCapability_AccessMode_MULTI_NODE_MULTI_WRITER,
144+
},
145+
}
146+
publishVolumeInfo := map[string]string{
147+
"first": "foo",
148+
"second": "bar",
149+
"third": "baz",
150+
}
151+
defaultRequest := &csi.ControllerPublishVolumeRequest{
152+
VolumeId: defaultVolumeID,
153+
NodeId: defaultNodeID,
154+
VolumeCapability: defaultCaps,
155+
Readonly: false,
156+
}
157+
158+
// Setup mock outout
159+
out := &csi.ControllerPublishVolumeResponse{
160+
PublishInfo: publishVolumeInfo,
161+
}
162+
163+
// Setup expectation
164+
// !IMPORTANT!: Must set context expected value to gomock.Any() to match any value
165+
driver.EXPECT().ControllerPublishVolume(gomock.Any(), pbMatch(defaultRequest)).Return(out, nil).Times(1)
166+
167+
// Create a new RPC
168+
server := mock_driver.NewMockCSIDriver(&mock_driver.MockCSIDriverServers{
169+
Controller: driver,
170+
})
171+
conn, err := server.Nexus()
172+
if err != nil {
173+
t.Errorf("Error: %s", err.Error())
174+
}
175+
defer server.Close()
176+
177+
// Make call
178+
c := csi.NewControllerClient(conn)
179+
r, err := c.ControllerPublishVolume(context.Background(), defaultRequest)
180+
if err != nil {
181+
t.Errorf("Error: %s", err.Error())
182+
}
183+
184+
info := r.GetPublishInfo()
185+
if !reflect.DeepEqual(info, publishVolumeInfo) {
186+
t.Errorf("Invalid publish info: %v", info)
187+
}
188+
}

0 commit comments

Comments
 (0)