Skip to content

Commit 873ef3e

Browse files
committed
PR: simplify type checking
Signed-off-by: Hamza El-Saawy <[email protected]>
1 parent b034afc commit 873ef3e

File tree

1 file changed

+17
-29
lines changed

1 file changed

+17
-29
lines changed

Diff for: hvsock_test.go

+17-29
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,7 @@ func clientServer(u testUtil) (cl, sv *HvsockConn, _ *HvsockAddr) {
5757
if err != nil {
5858
return fmt.Errorf("listener accept: %w", err)
5959
}
60-
var ok bool
61-
sv, ok = conn.(*HvsockConn)
62-
if !ok {
63-
return fmt.Errorf("expected connection type %T; got %T", new(HvsockConn), conn)
64-
}
60+
sv = mustBeType[*HvsockConn](u.T, conn)
6561
if err := l.Close(); err != nil {
6662
return err
6763
}
@@ -113,10 +109,7 @@ func TestHvSockListenerAddresses(t *testing.T) {
113109
u := newUtil(t)
114110
l, addr := serverListen(u)
115111

116-
la, ok := (l.Addr()).(*HvsockAddr)
117-
if !ok {
118-
t.Fatalf("expected type %T; got %T", new(HvsockAddr), l.Addr())
119-
}
112+
la := mustBeType[*HvsockAddr](t, l.Addr())
120113
u.Assert(*la == *addr, fmt.Sprintf("give: %v; want: %v", la, addr))
121114

122115
ra := rawHvsockAddr{}
@@ -130,22 +123,10 @@ func TestHvSockAddresses(t *testing.T) {
130123
u := newUtil(t)
131124
cl, sv, addr := clientServer(u)
132125

133-
sra, ok := (sv.RemoteAddr()).(*HvsockAddr)
134-
if !ok {
135-
t.Fatalf("expected type %T; got %T", new(HvsockAddr), sv.RemoteAddr())
136-
}
137-
sla, ok := (sv.LocalAddr()).(*HvsockAddr)
138-
if !ok {
139-
t.Fatalf("expected type %T; got %T", new(HvsockAddr), sv.LocalAddr())
140-
}
141-
cra, ok := (cl.RemoteAddr()).(*HvsockAddr)
142-
if !ok {
143-
t.Fatalf("expected type %T; got %T", new(HvsockAddr), cl.RemoteAddr())
144-
}
145-
cla, ok := (cl.LocalAddr()).(*HvsockAddr)
146-
if !ok {
147-
t.Fatalf("expected type %T; got %T", new(HvsockAddr), cl.LocalAddr())
148-
}
126+
sra := mustBeType[*HvsockAddr](t, sv.RemoteAddr())
127+
sla := mustBeType[*HvsockAddr](t, sv.LocalAddr())
128+
cra := mustBeType[*HvsockAddr](t, cl.RemoteAddr())
129+
cla := mustBeType[*HvsockAddr](t, cl.LocalAddr())
149130

150131
t.Run("Info", func(t *testing.T) {
151132
tests := []struct {
@@ -341,10 +322,7 @@ func TestHvSockCloseReadWriteListener(t *testing.T) {
341322
}
342323
defer c.Close()
343324

344-
hv, ok := c.(*HvsockConn)
345-
if !ok {
346-
t.Fatalf("expected type %T; got %T", new(HvsockConn), c)
347-
}
325+
hv := mustBeType[*HvsockConn](t, c)
348326
//
349327
// test CloseWrite()
350328
//
@@ -683,3 +661,13 @@ func (u testUtil) Check() {
683661
func msgJoin(pre []string, s string) string {
684662
return strings.Join(append(pre, s), ": ")
685663
}
664+
665+
func mustBeType[T any](tb testing.TB, v any) T {
666+
tb.Helper()
667+
668+
v2, ok := v.(T)
669+
if !ok {
670+
tb.Fatalf("expected type %T; got %T", *new(T), v)
671+
}
672+
return v2
673+
}

0 commit comments

Comments
 (0)