diff --git a/client.go b/client.go index 61772b6..aa37abc 100644 --- a/client.go +++ b/client.go @@ -5,10 +5,7 @@ package sse import ( - "bytes" "context" - "encoding/base64" - "errors" "fmt" "io" "net/http" @@ -19,19 +16,18 @@ import ( "gopkg.in/cenkalti/backoff.v1" ) -var ( - headerID = []byte("id:") - headerData = []byte("data:") - headerEvent = []byte("event:") - headerRetry = []byte("retry:") -) - func ClientMaxBufferSize(s int) func(c *Client) { return func(c *Client) { c.maxBufferSize = s } } +func ClientWithComments() func(c *Client) { + return func(c *Client) { + c.EventParseConfig.Comments = true + } +} + // ConnCallback defines a function to be called on a particular connection event type ConnCallback func(c *Client) @@ -53,8 +49,8 @@ type Client struct { LastEventID atomic.Value // []byte maxBufferSize int mu sync.Mutex - EncodingBase64 bool Connected bool + EventParseConfig } // NewClient creates a new client @@ -233,7 +229,7 @@ func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan c // If we get an error, ignore it. var msg *Event - if msg, err = c.processEvent(event); err == nil { + if msg, err = ParseEvent(event, c.EventParseConfig); err == nil { if len(msg.ID) > 0 { c.LastEventID.Store(msg.ID) } else { @@ -241,7 +237,7 @@ func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan c } // Send downstream if the event has something useful - if msg.hasContent() { + if msg.hasContent() || (c.EventParseConfig.Comments && msg.hasComment()) { outCh <- msg } } @@ -319,49 +315,6 @@ func (c *Client) request(ctx context.Context, stream string) (*http.Response, er return c.Connection.Do(req) } -func (c *Client) processEvent(msg []byte) (event *Event, err error) { - var e Event - - if len(msg) < 1 { - return nil, errors.New("event message was empty") - } - - // Normalize the crlf to lf to make it easier to split the lines. - // Split the line by "\n" or "\r", per the spec. - for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) { - switch { - case bytes.HasPrefix(line, headerID): - e.ID = append([]byte(nil), trimHeader(len(headerID), line)...) - case bytes.HasPrefix(line, headerData): - // The spec allows for multiple data fields per event, concatenated them with "\n". - e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...) - // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body. - case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))): - e.Data = append(e.Data, byte('\n')) - case bytes.HasPrefix(line, headerEvent): - e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...) - case bytes.HasPrefix(line, headerRetry): - e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...) - default: - // Ignore any garbage that doesn't match what we're looking for. - } - } - - // Trim the last "\n" per the spec. - e.Data = bytes.TrimSuffix(e.Data, []byte("\n")) - - if c.EncodingBase64 { - buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data))) - - n, err := base64.StdEncoding.Decode(buf, e.Data) - if err != nil { - err = fmt.Errorf("failed to decode event message: %s", err) - } - e.Data = buf[:n] - } - return &e, err -} - func (c *Client) cleanup(ch chan *Event) { c.mu.Lock() defer c.mu.Unlock() @@ -371,20 +324,3 @@ func (c *Client) cleanup(ch chan *Event) { delete(c.subscribed, ch) } } - -func trimHeader(size int, data []byte) []byte { - if data == nil || len(data) < size { - return data - } - - data = data[size:] - // Remove optional leading whitespace - if len(data) > 0 && data[0] == 32 { - data = data[1:] - } - // Remove trailing new line - if len(data) > 0 && data[len(data)-1] == 10 { - data = data[:len(data)-1] - } - return data -} diff --git a/client_test.go b/client_test.go index 3468df0..6b3e7d3 100644 --- a/client_test.go +++ b/client_test.go @@ -355,7 +355,7 @@ func TestClientLargeData(t *testing.T) { require.Equal(t, data, d) } -func TestClientComment(t *testing.T) { +func TestClientCommentIgnored(t *testing.T) { srv = newServer() defer cleanup() @@ -375,6 +375,30 @@ func TestClientComment(t *testing.T) { c.Unsubscribe(events) } +func TestClientWithComments(t *testing.T) { + srv = newServer() + defer cleanup() + + c := NewClient(urlPath, ClientWithComments()) + + events := make(chan *Event) + err := c.SubscribeChan("test", events) + require.Nil(t, err) + + srv.Publish("test", &Event{Comment: []byte("comment")}) + srv.Publish("test", &Event{Data: []byte("test")}) + + ev, err := waitEvent(events, time.Second*1) + assert.Nil(t, err) + assert.Equal(t, []byte("comment"), ev.Comment) + + ev, err = waitEvent(events, time.Second*1) + assert.Nil(t, err) + assert.Equal(t, []byte("test"), ev.Data) + + c.Unsubscribe(events) +} + func TestTrimHeader(t *testing.T) { tests := []struct { input []byte diff --git a/event.go b/event.go index 1258038..007d437 100644 --- a/event.go +++ b/event.go @@ -8,10 +8,27 @@ import ( "bufio" "bytes" "context" + "encoding/base64" + "errors" + "fmt" "io" + "net/http" "time" ) +var ( + headerID = []byte("id:") + headerData = []byte("data:") + headerEvent = []byte("event:") + headerRetry = []byte("retry:") + headerComment = []byte(":") +) + +type EventParseConfig struct { + EncodingBase64 bool + Comments bool +} + // Event holds all of the event source fields type Event struct { timestamp time.Time @@ -22,10 +39,138 @@ type Event struct { Comment []byte } +func ParseEvent(msg []byte, cfg EventParseConfig) (*Event, error) { + var e Event + var err error + + if len(msg) < 1 { + return nil, errors.New("event message was empty") + } + + // Normalize the crlf to lf to make it easier to split the lines. + // Split the line by "\n" or "\r", per the spec. + for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) { + switch { + case bytes.HasPrefix(line, headerID): + e.ID = append([]byte(nil), trimHeader(len(headerID), line)...) + case bytes.HasPrefix(line, headerData): + // The spec allows for multiple data fields per event, concatenated them with "\n". + e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...) + // The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body. + case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))): + e.Data = append(e.Data, byte('\n')) + case bytes.HasPrefix(line, headerEvent): + e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...) + case bytes.HasPrefix(line, headerRetry): + e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...) + case cfg.Comments && bytes.HasPrefix(line, headerComment): + e.Comment = append([]byte(nil), trimHeader(len(headerComment), line)...) + default: + // Ignore any garbage that doesn't match what we're looking for. + } + } + + // Trim the last "\n" per the spec. + e.Data = bytes.TrimSuffix(e.Data, []byte("\n")) + + if cfg.EncodingBase64 { + buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data))) + + n, decodeErr := base64.StdEncoding.Decode(buf, e.Data) + if decodeErr != nil { + err = fmt.Errorf("failed to decode event message: %s", decodeErr) + } + e.Data = buf[:n] + } + return &e, err +} + +func trimHeader(size int, data []byte) []byte { + if data == nil || len(data) < size { + return data + } + + data = data[size:] + // Remove optional leading whitespace + if len(data) > 0 && data[0] == 32 { + data = data[1:] + } + // Remove trailing new line + if len(data) > 0 && data[len(data)-1] == 10 { + data = data[:len(data)-1] + } + return data +} + +type EventWriteConfig struct { + SplitData bool +} + +func (e *Event) Write(w io.Writer, cfg EventWriteConfig) (int, error) { + var nWritten int + writef := func(format string, a ...interface{}) error { + n, err := fmt.Fprintf(w, format, a...) + nWritten += n + return err + } + + if len(e.Data) > 0 { + if err := writef("id: %s\n", e.ID); err != nil { + return nWritten, err + } + + if cfg.SplitData { + sd := bytes.Split(e.Data, []byte("\n")) + for i := range sd { + if err := writef("data: %s\n", sd[i]); err != nil { + return nWritten, err + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + } else { + if bytes.HasPrefix(e.Data, []byte(":")) { + if err := writef("%s\n", e.Data); err != nil { + return nWritten, err + } + } else { + if err := writef("data: %s\n", e.Data); err != nil { + return nWritten, err + } + } + } + + if len(e.Event) > 0 { + if err := writef("event: %s\n", e.Event); err != nil { + return nWritten, err + } + } + + if len(e.Retry) > 0 { + if err := writef("retry: %s\n", e.Retry); err != nil { + return nWritten, err + } + } + } + + if len(e.Comment) > 0 { + if err := writef(": %s\n", e.Comment); err != nil { + return nWritten, err + } + } + + return nWritten, nil +} + func (e *Event) hasContent() bool { return len(e.ID) > 0 || len(e.Data) > 0 || len(e.Event) > 0 || len(e.Retry) > 0 } +func (e *Event) hasComment() bool { + return len(e.Comment) > 0 +} + // EventStreamReader scans an io.Reader looking for EventStream messages. type EventStreamReader struct { scanner *bufio.Scanner diff --git a/http.go b/http.go index c7a2b43..3dd7b70 100644 --- a/http.go +++ b/http.go @@ -5,7 +5,6 @@ package sse import ( - "bytes" "fmt" "net/http" "strconv" @@ -84,34 +83,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { continue } - if len(ev.Data) > 0 { - fmt.Fprintf(w, "id: %s\n", ev.ID) - - if s.SplitData { - sd := bytes.Split(ev.Data, []byte("\n")) - for i := range sd { - fmt.Fprintf(w, "data: %s\n", sd[i]) - } - } else { - if bytes.HasPrefix(ev.Data, []byte(":")) { - fmt.Fprintf(w, "%s\n", ev.Data) - } else { - fmt.Fprintf(w, "data: %s\n", ev.Data) - } - } - - if len(ev.Event) > 0 { - fmt.Fprintf(w, "event: %s\n", ev.Event) - } - - if len(ev.Retry) > 0 { - fmt.Fprintf(w, "retry: %s\n", ev.Retry) - } - } - - if len(ev.Comment) > 0 { - fmt.Fprintf(w, ": %s\n", ev.Comment) - } + ev.Write(w, EventWriteConfig{SplitData: s.SplitData}) fmt.Fprint(w, "\n")