From a4397d7bca5c41eea9d076b4cc6448ed042425bc Mon Sep 17 00:00:00 2001 From: cmilhench Date: Fri, 9 Aug 2024 17:47:35 +0100 Subject: [PATCH] add socket server --- Makefile | 2 +- cmd/chat/main.go | 61 ++++++++++ cmd/chat/static/index.html | 176 ++++++++++++++++++++++++++++ exp/http/socket/client.go | 60 ++++++++++ exp/http/socket/server.go | 119 +++++++++++++++++++ exp/http/static/static.go | 35 ++++++ exp/identifiers/identifiers.go | 13 +- exp/identifiers/identifiers_test.go | 8 +- exp/irc/message.go | 79 +++++++++++++ exp/irc/message_test.go | 40 +++++++ go.mod | 4 +- go.sum | 2 + 12 files changed, 584 insertions(+), 15 deletions(-) create mode 100644 cmd/chat/main.go create mode 100644 cmd/chat/static/index.html create mode 100644 exp/http/socket/client.go create mode 100644 exp/http/socket/server.go create mode 100644 exp/http/static/static.go create mode 100644 exp/irc/message.go create mode 100644 exp/irc/message_test.go diff --git a/Makefile b/Makefile index 7e8a5a6..477dbfe 100644 --- a/Makefile +++ b/Makefile @@ -56,7 +56,7 @@ test: lint ## Run the project tests start: test ## Start the server $(call cyan, "Running...") $(call setenv,) - @go run -ldflags '-w -s ' ./cmd/ + @go run -ldflags '-w -s ' ./cmd/chat/ .PHONY: start watch: ## Run locally and monitor for changes diff --git a/cmd/chat/main.go b/cmd/chat/main.go new file mode 100644 index 0000000..4536ceb --- /dev/null +++ b/cmd/chat/main.go @@ -0,0 +1,61 @@ +package main + +import ( + "embed" + "fmt" + "log" + "net/http" + "time" + + "github.com/cmilhench/x/exp/http/socket" + "github.com/cmilhench/x/exp/http/static" + "github.com/cmilhench/x/exp/irc" +) + +//go:embed static +var fs embed.FS + +func main() { + server := socket.NewSocketServer() + server.Handle(socketHandler(server)) + server.Start() + + http.Handle("/", http.FileServer(static.Neutered{Prefix: "static", FileSystem: http.FS(fs)})) + http.HandleFunc("/ws", server.HandleConnections) + + log.Println("Socket server started on :8080") + err := http.ListenAndServe(":8080", nil) + if err != nil { + log.Fatalf("ListenAndServe: %v", err) + } +} + +func socketHandler(server *socket.SocketServer) socket.MessageHandler { + return func(client *socket.Client, messageBytes []byte) { + message := irc.ParseMessage(string(messageBytes)) + log.Printf("message ->: %#v", message) + switch message.Command { + case "INFO": // returns information about the server + client.Send([]byte(fmt.Sprintf("INFO %s", "This is an IRC server."))) + case "MOTD": // returns the message of the day + client.Send([]byte(fmt.Sprintf("MOTD %s", "Welcome to the IRC server!"))) + case "NICK": // allows a client to change their IRC nickname. + client.Name = message.Params + case "PING": // tests the presence of a connection + client.Send([]byte(fmt.Sprintf("PONG %s", message.Params))) + case "NOTICE", "PRIVMSG": // Sends to , which is usually a user or channel. + if message.Params[0] == '#' { + server.Broadcast([]byte(fmt.Sprintf(":%s PRIVMSG %s", client.Name, message.Trailing))) + } else { + server.Send(message.Params, []byte(fmt.Sprintf(":%s PRIVMSG %s", client.Name, message.Trailing))) + } + case "QUIT": // disconnects the user from the server. + server.Part(client) + case "TIME": // returns the current time on the server + client.Send([]byte(time.Now().Format(time.RFC1123Z))) + case "TOPIC": // sets the topic of to + default: + log.Printf("Unknown message type: %#v", message) + } + } +} diff --git a/cmd/chat/static/index.html b/cmd/chat/static/index.html new file mode 100644 index 0000000..70cdabd --- /dev/null +++ b/cmd/chat/static/index.html @@ -0,0 +1,176 @@ + + + + + + WebSocket Client + + + +

WebSocket Client

+ +
+ + +
+ +

Messages:

+
+ + + + + + + + + diff --git a/exp/http/socket/client.go b/exp/http/socket/client.go new file mode 100644 index 0000000..80ff6c0 --- /dev/null +++ b/exp/http/socket/client.go @@ -0,0 +1,60 @@ +package socket + +import ( + "log" + "time" + + "github.com/cmilhench/x/exp/uuid" + + "github.com/gorilla/websocket" +) + +type Client struct { + conn *websocket.Conn + send chan []byte + id string + Name string +} + +type MessageHandler func(*Client, []byte) + +func NewClient(conn *websocket.Conn) *Client { + id, _ := uuid.New() + return &Client{ + id: id, + conn: conn, + send: make(chan []byte), + } +} + +func (client *Client) ReadMessages(fn MessageHandler) { + for { + _, msg, err := client.conn.ReadMessage() + if err != nil { + log.Printf("Read error: %v", err) + break + } + fn(client, msg) + } +} + +func (client *Client) WriteMessages() { + for msg := range client.send { + err := client.conn.WriteMessage(websocket.BinaryMessage, msg) + if err != nil { + log.Printf("Write error: %v", err) + break + } + } +} + +func (client *Client) Send(data []byte) { + client.send <- data +} + +func (client *Client) Close() { + close(client.send) + deadline := time.Now().Add(5 * time.Second) + data := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "") + _ = client.conn.WriteControl(websocket.CloseMessage, data, deadline) +} diff --git a/exp/http/socket/server.go b/exp/http/socket/server.go new file mode 100644 index 0000000..5dc51ea --- /dev/null +++ b/exp/http/socket/server.go @@ -0,0 +1,119 @@ +package socket + +import ( + "log" + "net/http" + "sync" + + "github.com/gorilla/websocket" +) + +type SocketServer struct { + clients map[*Client]struct{} + broadcast chan []byte + messages chan struct { + Target string + Data []byte + } + join chan *Client + part chan *Client + handler MessageHandler + mu sync.Mutex +} + +func NewSocketServer() *SocketServer { + return &SocketServer{ + clients: make(map[*Client]struct{}), + broadcast: make(chan []byte), + messages: make(chan struct { + Target string + Data []byte + }), + join: make(chan *Client), + part: make(chan *Client), + } +} + +func (s *SocketServer) Start() { + go func() { + for { + select { + case client := <-s.join: + s.mu.Lock() + s.clients[client] = struct{}{} + s.mu.Unlock() + log.Printf("Client joined: %v", client.conn.RemoteAddr()) + case client := <-s.part: + s.mu.Lock() + if _, ok := s.clients[client]; ok { + client.Close() + delete(s.clients, client) + log.Printf("Client left: %v", client.conn.RemoteAddr()) + } + s.mu.Unlock() + case data := <-s.broadcast: + s.mu.Lock() + for client := range s.clients { + select { + case client.send <- data: + default: + close(client.send) + delete(s.clients, client) + } + } + s.mu.Unlock() + case message := <-s.messages: + s.mu.Lock() + for k := range s.clients { + if k.id == message.Target || k.Name == message.Target { + select { + case k.send <- message.Data: + default: + close(k.send) + delete(s.clients, k) + } + return + } + } + s.mu.Unlock() + } + } + }() +} + +func (s *SocketServer) Broadcast(message []byte) { + s.broadcast <- message +} + +func (s *SocketServer) Handle(handler MessageHandler) { + s.handler = handler +} + +func (s *SocketServer) Send(target string, message []byte) { + s.messages <- struct { + Target string + Data []byte + }{target, message} +} + +func (s *SocketServer) Part(client *Client) { + s.part <- client +} + +func (s *SocketServer) HandleConnections(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Upgrade error: %v", err) + return + } + client := NewClient(conn) + + s.join <- client + + go client.WriteMessages() + client.ReadMessages(s.handler) +} diff --git a/exp/http/static/static.go b/exp/http/static/static.go new file mode 100644 index 0000000..87423ec --- /dev/null +++ b/exp/http/static/static.go @@ -0,0 +1,35 @@ +package static + +import ( + "net/http" + "path" + "path/filepath" +) + +// neutered is a http file system wrapper that disables FileServer Directory Listings +// and roots every path in /static +type Neutered struct { + Prefix string + FileSystem http.FileSystem +} + +func (n Neutered) Open(name string) (http.File, error) { + name = path.Join(n.Prefix, name) + f, err := n.FileSystem.Open(name) + if err != nil { + return nil, err + } + s, _ := f.Stat() + if s.IsDir() { + index := filepath.Join(name, "index.html") + if _, err := n.FileSystem.Open(index); err != nil { + closeErr := f.Close() + if closeErr != nil { + return nil, closeErr + } + + return nil, err + } + } + return f, nil +} diff --git a/exp/identifiers/identifiers.go b/exp/identifiers/identifiers.go index ca07ef8..4142509 100644 --- a/exp/identifiers/identifiers.go +++ b/exp/identifiers/identifiers.go @@ -1,7 +1,6 @@ package identifiers import ( - "strconv" "time" ) @@ -9,13 +8,13 @@ import ( // 41 bits = milliseconds from epoch (max:2199023255551 = ~69 years) // 10 bits = shard (max:1024) // 12 bits = auto-incrementing and wrapping index (max:4095) see % -func Creator(shard uint16) func() string { +func Creator(shard uint16) func() uint64 { e := int64(1577836800000) // time.Parse(time.RFC3339, "2020-01-01T00:00:00Z") l := time.Now().UnixMilli() - e i := shard % 1024 s := 0 time.Sleep(time.Millisecond) - return func() string { + return func() uint64 { var hash uint64 n := time.Now().UnixMilli() - e if n == l { @@ -39,17 +38,13 @@ func Creator(shard uint16) func() string { // set the last 12 bits (%4095) of the uint64 by shifting S left by 0 // 0000000000000000000000000000000000000000000000000000111111111111 hash |= uint64(s) - return strconv.FormatUint(hash, 36) + return hash } } // Parse an identifier into it's components -func Parse(key string) (time.Time, uint64, uint64, error) { +func Parse(hash uint64) (time.Time, uint64, uint64, error) { e := int64(1577836800000) // time.Parse(time.RFC3339, "2020-01-01T00:00:00Z") - hash, err := strconv.ParseUint(key, 36, 64) - if err != nil { - return time.Time{}, 0, 0, err - } n := (hash << (1)) >> 23 i := (hash << (42)) >> 54 s := (hash << (52)) >> 52 diff --git a/exp/identifiers/identifiers_test.go b/exp/identifiers/identifiers_test.go index c3eba85..3e20e86 100644 --- a/exp/identifiers/identifiers_test.go +++ b/exp/identifiers/identifiers_test.go @@ -6,8 +6,8 @@ import ( ) func Test_Creator(t *testing.T) { - ids := make(chan string) - store := make(map[string]bool) + ids := make(chan uint64) + store := make(map[uint64]bool) workers := 1024 var wg sync.WaitGroup @@ -31,13 +31,13 @@ func Test_Creator(t *testing.T) { for id := range ids { if _, exists := store[id]; exists { - t.Errorf("Duplicate ID generated: %s", id) + t.Errorf("Duplicate ID generated: %d", id) } store[id] = true _, _, _, err := Parse(id) if err != nil { - t.Errorf("Failed to parse ID: %s", id) + t.Errorf("Failed to parse ID: %d", id) } } } diff --git a/exp/irc/message.go b/exp/irc/message.go new file mode 100644 index 0000000..c40a06e --- /dev/null +++ b/exp/irc/message.go @@ -0,0 +1,79 @@ +package irc + +import ( + "fmt" + "strings" +) + +type Message struct { + Prefix string + Command string + Params string + Trailing string + _raw string +} + +func ParseMessage(line string) *Message { + c := &Message{} + c.Parse(line) + return c +} + +func (c *Message) Parse(line string) { + line = strings.TrimSuffix(line, "\r") + line = strings.TrimSuffix(line, "\r\n") + orig := line + c._raw = orig + // Prefix + if line[0] == ':' { + i := strings.Index(line, " ") + c.Prefix = line[1:i] + line = line[i+1:] + } + // Command + i := strings.Index(line, " ") + if i == -1 { + i = len(line) + } + c.Command = line[0:i] + line = line[i:] + // Params + i = strings.Index(line, " :") + if i == -1 { + i = len(line) + } + if i != 0 { + c.Params = line[1:i] + } + // Trailing + if len(line)-i > 2 { + c.Trailing = line[i+2:] + } +} + +func (c *Message) String() string { + var line string + if len(c.Prefix) > 0 { + line = fmt.Sprintf(":%s ", c.Prefix) + } + line += c.Command + if len(c.Params) > 0 { + line = fmt.Sprintf("%s %s", line, c.Params) + } + if len(c.Trailing) > 0 { + line = fmt.Sprintf("%s :%s", line, c.Trailing) + } + return line +} + +func (c *Message) Nick() string { + return c.Prefix[0:strings.Index(c.Prefix, "!")] +} + +func (c *Message) Username() string { + return c.Prefix[strings.Index(c.Prefix, "!")+1 : strings.Index(c.Prefix, "@")] +} + +func (c *Message) Hostname() string { + return c.Prefix[strings.Index(c.Prefix, "@")+1:] +} diff --git a/exp/irc/message_test.go b/exp/irc/message_test.go new file mode 100644 index 0000000..668741d --- /dev/null +++ b/exp/irc/message_test.go @@ -0,0 +1,40 @@ +package irc + +import ( + "fmt" + "testing" +) + +func TestMessage(t *testing.T) { + var tests = []struct { + name string + line string + Message + }{ + {"1", ":example.freenode.net NOTICE * :*** Looking up your hostname...\r\n", Message{"example.freenode.net", "NOTICE", "*", "*** Looking up your hostname...", ""}}, + {"2", "ERROR :Closing Link: 127.0.0.1 (Connection timed out)\r\n", Message{"", "ERROR", "", "Closing Link: 127.0.0.1 (Connection timed out)", ""}}, + {"3", ":user!~mail@example.net JOIN #channel\r\n", Message{"user!~mail@example.net", "JOIN", "#channel", "", ""}}, + {"4", ":user!~mail@example.com PRIVMSG user :Hello :)\r\n", Message{"user!~mail@example.com", "PRIVMSG", "user", "Hello :)", ""}}, + {"6", ":user!~mail@example.com PRIVMSG #channel :Hello :)\r\n", Message{"user!~mail@example.com", "PRIVMSG", "#channel", "Hello :)", ""}}, + {"6", ":NickServ!NickServ@services. NOTICE user :Some message.\r\n", Message{"NickServ!NickServ@services.", "NOTICE", "user", "Some message.", ""}}, + {"7", ":user PRIVMSG #chan :Hello!\r\n", Message{"user", "PRIVMSG", "#chan", "Hello!", ""}}, + } + for _, test := range tests { + t.Run(fmt.Sprintf("method%v", test.name), func(t *testing.T) { + m := Message{} + m.Parse(test.line) + if m.Prefix != test.Prefix { + t.Errorf("expected prefix '%s', got '%s'", test.Prefix, m.Prefix) + } + if m.Command != test.Command { + t.Errorf("expected command '%s', got '%s'", test.Command, m.Command) + } + if m.Params != test.Params { + t.Errorf("expected params '%s', got '%s'", test.Params, m.Params) + } + if m.Trailing != test.Trailing { + t.Errorf("expected trailing '%s', got '%s'", test.Trailing, m.Trailing) + } + }) + } +} diff --git a/go.mod b/go.mod index bbb8e0a..a0d122c 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/cmilhench/x -go 1.22.5 +go 1.22.6 + +require github.com/gorilla/websocket v1.5.3 diff --git a/go.sum b/go.sum index e69de29..25a9fc4 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=