diff --git a/pkg/ping/net.go b/pkg/ping/net.go index a355435..f1f80df 100644 --- a/pkg/ping/net.go +++ b/pkg/ping/net.go @@ -3,23 +3,21 @@ package ping import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "net" "time" ) -type result struct { - Status string - ResponseDuration time.Duration -} +var ErrInvalidResponse = errors.New("invalid response") -func OpenConnection(host string, port, timeout int) (result, error) { +func OpenConnection(host string, port, timeout int) (int, error) { address := fmt.Sprintf("%v:%v", host, port) conn, err := net.DialTimeout("tcp", address, time.Millisecond*time.Duration(timeout)) connectTime := time.Now() if err != nil { - return result{}, err + return 0, err } defer conn.Close() @@ -28,25 +26,20 @@ func OpenConnection(host string, port, timeout int) (result, error) { _, err = conn.Read(buf) responseTime := time.Now() if err != nil && err != io.EOF { - return result{}, err + return 0, err } var opcode uint16 reader := bytes.NewReader(buf[2:4]) err = binary.Read(reader, binary.LittleEndian, &opcode) if err != nil { - return result{}, err + return 0, err } - responseDuration := responseTime.Sub(connectTime).Round(time.Millisecond) - - status := "fail" - if opcode == SMSG_AUTH_CHALLENGE { - status = "success" + if opcode != SMSG_AUTH_CHALLENGE { + return 0, ErrInvalidResponse } - return result{ - Status: status, - ResponseDuration: responseDuration, - }, nil + res := responseTime.Sub(connectTime).Milliseconds() + return int(res), nil } diff --git a/wowPing.go b/wowPing.go index a42936c..4dd886a 100644 --- a/wowPing.go +++ b/wowPing.go @@ -40,14 +40,12 @@ func main() { for _, server := range group.List { stat := statistics[server.Name] - result, err := ping.OpenConnection(server.Host, server.Port, params.Timeout) + responseTime, err := ping.OpenConnection(server.Host, server.Port, params.Timeout) - if err == nil && result.Status == "success" { - stat.ResponseDurations = append(stat.ResponseDurations, int(result.ResponseDuration.Milliseconds())) - fmt.Println(server.Name, result.ResponseDuration) - } - - if err != nil { + if err == nil { + stat.ResponseDurations = append(stat.ResponseDurations, responseTime) + fmt.Printf("%v %vms\n", server.Name, responseTime) + } else { if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, os.ErrDeadlineExceeded) { stat.Timeouts++ fmt.Println(server.Name, "timeout")