Skip to content

Commit c942acd

Browse files
authored
feat: make server acknowledge request (#23)
1 parent 914fd83 commit c942acd

File tree

7 files changed

+125
-3
lines changed

7 files changed

+125
-3
lines changed

cmd/client/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ service, ip and hostname rather than only pods.`,
224224
cf.AddFlags(flags)
225225
flags.BoolVarP(&printVersion, "version", "V", false, "Print version info and exit.")
226226
flags.StringVar(&o.address, "address", "127.0.0.1", "Address to listen on. Only accepts IP addresses as a value.")
227-
flags.StringVar(&o.serverImage, "server.image", constants.ServerImage, "The krelay-server image to use.")
227+
flags.StringVar(&o.serverImage, "server.image", "ghcr.io/knight42/krelay-server:v0.0.2", "The krelay-server image to use.")
228228
flags.StringVar(&o.serverNamespace, "server.namespace", metav1.NamespaceDefault, "The namespace in which krelay-server is located.")
229229

230230
// I do not want these flags to show up in --help.

cmd/client/tcp.go

+11
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ func handleTCPConn(clientConn net.Conn, serverConn httpstream.Connection, dstAdd
4343
return
4444
}
4545

46+
var ack xnet.Acknowledgement
47+
err = ack.FromReader(dataStream)
48+
if err != nil {
49+
klog.ErrorS(err, "Fail to receive ack", kvs...)
50+
return
51+
}
52+
if ack.Code != xnet.AckCodeOK {
53+
klog.ErrorS(ack.Code, "Fail to connect", kvs...)
54+
return
55+
}
56+
4657
localError := make(chan struct{})
4758
remoteDone := make(chan struct{})
4859

cmd/client/udp.go

+11
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ func handleUDPConn(clientConn net.PacketConn, cliAddr net.Addr, dataCh chan []by
4343
return
4444
}
4545

46+
var ack xnet.Acknowledgement
47+
err = ack.FromReader(dataStream)
48+
if err != nil {
49+
klog.ErrorS(err, "Fail to receive ack", kvs...)
50+
return
51+
}
52+
if ack.Code != xnet.AckCodeOK {
53+
klog.ErrorS(ack.Code, "Fail to connect", kvs...)
54+
return
55+
}
56+
4657
upClosed := make(chan struct{})
4758
go func() {
4859
var (

cmd/server/main.go

+46
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"errors"
56
"flag"
67
"fmt"
78
"net"
@@ -40,6 +41,31 @@ func (o *options) run(ctx context.Context) error {
4041
}
4142
}
4243

44+
func writeACK(c net.Conn, ack xnet.Acknowledgement) error {
45+
data := ack.Marshal()
46+
_, err := c.Write(data)
47+
return err
48+
}
49+
50+
func ackCodeFromErr(err error) xnet.AckCode {
51+
var dnsErr *net.DNSError
52+
if errors.As(err, &dnsErr) {
53+
if dnsErr.IsNotFound {
54+
return xnet.AckCodeNoSuchHost
55+
}
56+
if dnsErr.IsTimeout {
57+
return xnet.AckCodeResolveTimeout
58+
}
59+
}
60+
61+
var opErr *net.OpError
62+
if errors.As(err, &opErr) && opErr.Timeout() {
63+
return xnet.AckCodeConnectTimeout
64+
}
65+
66+
return xnet.AckCodeUnknownError
67+
}
68+
4369
func handleConn(ctx context.Context, c *net.TCPConn, dialer *net.Dialer) {
4470
defer c.Close()
4571

@@ -57,6 +83,16 @@ func handleConn(ctx context.Context, c *net.TCPConn, dialer *net.Dialer) {
5783
upstreamConn, err := dialer.DialContext(ctx, constants.ProtocolTCP, dstAddr)
5884
if err != nil {
5985
klog.ErrorS(err, "Fail to create tcp connection", constants.LogFieldRequestID, hdr.RequestID, constants.LogFieldDestAddr, dstAddr)
86+
_ = writeACK(c, xnet.Acknowledgement{
87+
Code: ackCodeFromErr(err),
88+
})
89+
return
90+
}
91+
err = writeACK(c, xnet.Acknowledgement{
92+
Code: xnet.AckCodeOK,
93+
})
94+
if err != nil {
95+
klog.ErrorS(err, "Fail to write ack", constants.LogFieldRequestID, hdr.RequestID)
6096
return
6197
}
6298
klog.InfoS("Start proxy tcp request", constants.LogFieldRequestID, hdr.RequestID, constants.LogFieldDestAddr, dstAddr)
@@ -66,6 +102,16 @@ func handleConn(ctx context.Context, c *net.TCPConn, dialer *net.Dialer) {
66102
upstreamConn, err := dialer.DialContext(ctx, constants.ProtocolUDP, dstAddr)
67103
if err != nil {
68104
klog.ErrorS(err, "Fail to create udp connection", constants.LogFieldRequestID, hdr.RequestID, constants.LogFieldDestAddr, dstAddr)
105+
_ = writeACK(c, xnet.Acknowledgement{
106+
Code: ackCodeFromErr(err),
107+
})
108+
return
109+
}
110+
err = writeACK(c, xnet.Acknowledgement{
111+
Code: xnet.AckCodeOK,
112+
})
113+
if err != nil {
114+
klog.ErrorS(err, "Fail to write ack", constants.LogFieldRequestID, hdr.RequestID)
69115
return
70116
}
71117
klog.InfoS("Start proxy udp request", constants.LogFieldRequestID, hdr.RequestID, constants.LogFieldDestAddr, dstAddr)

cmd/server/tcp_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ func TestHandleTCPConn(t *testing.T) {
4949
if err != nil {
5050
return nil, fmt.Errorf("write header: %w", err)
5151
}
52+
var ack xnet.Acknowledgement
53+
err = ack.FromReader(c)
54+
if err != nil {
55+
return nil, fmt.Errorf("read ack: %w", err)
56+
}
57+
if ack.Code != xnet.AckCodeOK {
58+
return nil, fmt.Errorf("ack: %s", ack.Code.Error())
59+
}
5260
return c, nil
5361
}
5462

pkg/constants/constants.go

-2
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,3 @@ const (
2424
ProtocolTCP = "tcp"
2525
ProtocolUDP = "udp"
2626
)
27-
28-
const ServerImage = "ghcr.io/knight42/krelay-server:v0.0.1"

pkg/xnet/ack.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package xnet
2+
3+
import (
4+
"fmt"
5+
"io"
6+
)
7+
8+
type AckCode uint8
9+
10+
const (
11+
AckCodeOK = iota + 1
12+
AckCodeUnknownError
13+
AckCodeNoSuchHost
14+
AckCodeResolveTimeout
15+
AckCodeConnectTimeout
16+
)
17+
18+
func (c AckCode) Error() string {
19+
switch c {
20+
case AckCodeUnknownError:
21+
return "Unknown error"
22+
case AckCodeNoSuchHost:
23+
return "No such host"
24+
case AckCodeResolveTimeout:
25+
return "Resolve timeout"
26+
case AckCodeConnectTimeout:
27+
return "Connect timeout"
28+
}
29+
return "Unknown Code"
30+
}
31+
32+
type Acknowledgement struct {
33+
Code AckCode
34+
}
35+
36+
func (a *Acknowledgement) Marshal() []byte {
37+
return []byte{byte(a.Code)}
38+
}
39+
40+
func (a *Acknowledgement) FromReader(r io.Reader) error {
41+
var buf [1]byte
42+
_, err := r.Read(buf[:])
43+
if err != nil {
44+
return fmt.Errorf("read ack: %w", err)
45+
}
46+
a.Code = AckCode(buf[0])
47+
return nil
48+
}

0 commit comments

Comments
 (0)