Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 9 additions & 73 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
package sse

import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -19,19 +16,18 @@ import (
"gopkg.in/cenkalti/backoff.v1"
)

var (
headerID = []byte("id:")
headerData = []byte("data:")
headerEvent = []byte("event:")
headerRetry = []byte("retry:")
)

func ClientMaxBufferSize(s int) func(c *Client) {
return func(c *Client) {
c.maxBufferSize = s
}
}

func ClientWithComments() func(c *Client) {
return func(c *Client) {
c.EventParseConfig.Comments = true
}
}

// ConnCallback defines a function to be called on a particular connection event
type ConnCallback func(c *Client)

Expand All @@ -53,8 +49,8 @@ type Client struct {
LastEventID atomic.Value // []byte
maxBufferSize int
mu sync.Mutex
EncodingBase64 bool
Connected bool
EventParseConfig
}

// NewClient creates a new client
Expand Down Expand Up @@ -233,15 +229,15 @@ func (c *Client) readLoop(reader *EventStreamReader, outCh chan *Event, erChan c

// If we get an error, ignore it.
var msg *Event
if msg, err = c.processEvent(event); err == nil {
if msg, err = ParseEvent(event, c.EventParseConfig); err == nil {
if len(msg.ID) > 0 {
c.LastEventID.Store(msg.ID)
} else {
msg.ID, _ = c.LastEventID.Load().([]byte)
}

// Send downstream if the event has something useful
if msg.hasContent() {
if msg.hasContent() || (c.EventParseConfig.Comments && msg.hasComment()) {
outCh <- msg
}
}
Expand Down Expand Up @@ -319,49 +315,6 @@ func (c *Client) request(ctx context.Context, stream string) (*http.Response, er
return c.Connection.Do(req)
}

func (c *Client) processEvent(msg []byte) (event *Event, err error) {
var e Event

if len(msg) < 1 {
return nil, errors.New("event message was empty")
}

// Normalize the crlf to lf to make it easier to split the lines.
// Split the line by "\n" or "\r", per the spec.
for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
switch {
case bytes.HasPrefix(line, headerID):
e.ID = append([]byte(nil), trimHeader(len(headerID), line)...)
case bytes.HasPrefix(line, headerData):
// The spec allows for multiple data fields per event, concatenated them with "\n".
e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
// The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
e.Data = append(e.Data, byte('\n'))
case bytes.HasPrefix(line, headerEvent):
e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...)
case bytes.HasPrefix(line, headerRetry):
e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...)
default:
// Ignore any garbage that doesn't match what we're looking for.
}
}

// Trim the last "\n" per the spec.
e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))

if c.EncodingBase64 {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))

n, err := base64.StdEncoding.Decode(buf, e.Data)
if err != nil {
err = fmt.Errorf("failed to decode event message: %s", err)
}
e.Data = buf[:n]
}
return &e, err
}

func (c *Client) cleanup(ch chan *Event) {
c.mu.Lock()
defer c.mu.Unlock()
Expand All @@ -371,20 +324,3 @@ func (c *Client) cleanup(ch chan *Event) {
delete(c.subscribed, ch)
}
}

func trimHeader(size int, data []byte) []byte {
if data == nil || len(data) < size {
return data
}

data = data[size:]
// Remove optional leading whitespace
if len(data) > 0 && data[0] == 32 {
data = data[1:]
}
// Remove trailing new line
if len(data) > 0 && data[len(data)-1] == 10 {
data = data[:len(data)-1]
}
return data
}
26 changes: 25 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func TestClientLargeData(t *testing.T) {
require.Equal(t, data, d)
}

func TestClientComment(t *testing.T) {
func TestClientCommentIgnored(t *testing.T) {
srv = newServer()
defer cleanup()

Expand All @@ -375,6 +375,30 @@ func TestClientComment(t *testing.T) {
c.Unsubscribe(events)
}

func TestClientWithComments(t *testing.T) {
srv = newServer()
defer cleanup()

c := NewClient(urlPath, ClientWithComments())

events := make(chan *Event)
err := c.SubscribeChan("test", events)
require.Nil(t, err)

srv.Publish("test", &Event{Comment: []byte("comment")})
srv.Publish("test", &Event{Data: []byte("test")})

ev, err := waitEvent(events, time.Second*1)
assert.Nil(t, err)
assert.Equal(t, []byte("comment"), ev.Comment)

ev, err = waitEvent(events, time.Second*1)
assert.Nil(t, err)
assert.Equal(t, []byte("test"), ev.Data)

c.Unsubscribe(events)
}

func TestTrimHeader(t *testing.T) {
tests := []struct {
input []byte
Expand Down
145 changes: 145 additions & 0 deletions event.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,27 @@ import (
"bufio"
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"time"
)

var (
headerID = []byte("id:")
headerData = []byte("data:")
headerEvent = []byte("event:")
headerRetry = []byte("retry:")
headerComment = []byte(":")
)

type EventParseConfig struct {
EncodingBase64 bool
Comments bool
}

// Event holds all of the event source fields
type Event struct {
timestamp time.Time
Expand All @@ -22,10 +39,138 @@ type Event struct {
Comment []byte
}

func ParseEvent(msg []byte, cfg EventParseConfig) (*Event, error) {
var e Event
var err error

if len(msg) < 1 {
return nil, errors.New("event message was empty")
}

// Normalize the crlf to lf to make it easier to split the lines.
// Split the line by "\n" or "\r", per the spec.
for _, line := range bytes.FieldsFunc(msg, func(r rune) bool { return r == '\n' || r == '\r' }) {
switch {
case bytes.HasPrefix(line, headerID):
e.ID = append([]byte(nil), trimHeader(len(headerID), line)...)
case bytes.HasPrefix(line, headerData):
// The spec allows for multiple data fields per event, concatenated them with "\n".
e.Data = append(e.Data[:], append(trimHeader(len(headerData), line), byte('\n'))...)
// The spec says that a line that simply contains the string "data" should be treated as a data field with an empty body.
case bytes.Equal(line, bytes.TrimSuffix(headerData, []byte(":"))):
e.Data = append(e.Data, byte('\n'))
case bytes.HasPrefix(line, headerEvent):
e.Event = append([]byte(nil), trimHeader(len(headerEvent), line)...)
case bytes.HasPrefix(line, headerRetry):
e.Retry = append([]byte(nil), trimHeader(len(headerRetry), line)...)
case cfg.Comments && bytes.HasPrefix(line, headerComment):
e.Comment = append([]byte(nil), trimHeader(len(headerComment), line)...)
default:
// Ignore any garbage that doesn't match what we're looking for.
}
}

// Trim the last "\n" per the spec.
e.Data = bytes.TrimSuffix(e.Data, []byte("\n"))

if cfg.EncodingBase64 {
buf := make([]byte, base64.StdEncoding.DecodedLen(len(e.Data)))

n, decodeErr := base64.StdEncoding.Decode(buf, e.Data)
if decodeErr != nil {
err = fmt.Errorf("failed to decode event message: %s", decodeErr)
}
e.Data = buf[:n]
}
return &e, err
}

func trimHeader(size int, data []byte) []byte {
if data == nil || len(data) < size {
return data
}

data = data[size:]
// Remove optional leading whitespace
if len(data) > 0 && data[0] == 32 {
data = data[1:]
}
// Remove trailing new line
if len(data) > 0 && data[len(data)-1] == 10 {
data = data[:len(data)-1]
}
return data
}

type EventWriteConfig struct {
SplitData bool
}

func (e *Event) Write(w io.Writer, cfg EventWriteConfig) (int, error) {
var nWritten int
writef := func(format string, a ...interface{}) error {
n, err := fmt.Fprintf(w, format, a...)
nWritten += n
return err
}

if len(e.Data) > 0 {
if err := writef("id: %s\n", e.ID); err != nil {
return nWritten, err
}

if cfg.SplitData {
sd := bytes.Split(e.Data, []byte("\n"))
for i := range sd {
if err := writef("data: %s\n", sd[i]); err != nil {
return nWritten, err
}
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
}
} else {
if bytes.HasPrefix(e.Data, []byte(":")) {
if err := writef("%s\n", e.Data); err != nil {
return nWritten, err
}
} else {
if err := writef("data: %s\n", e.Data); err != nil {
return nWritten, err
}
}
}

if len(e.Event) > 0 {
if err := writef("event: %s\n", e.Event); err != nil {
return nWritten, err
}
}

if len(e.Retry) > 0 {
if err := writef("retry: %s\n", e.Retry); err != nil {
return nWritten, err
}
}
}

if len(e.Comment) > 0 {
if err := writef(": %s\n", e.Comment); err != nil {
return nWritten, err
}
}

return nWritten, nil
}

func (e *Event) hasContent() bool {
return len(e.ID) > 0 || len(e.Data) > 0 || len(e.Event) > 0 || len(e.Retry) > 0
}

func (e *Event) hasComment() bool {
return len(e.Comment) > 0
}

// EventStreamReader scans an io.Reader looking for EventStream messages.
type EventStreamReader struct {
scanner *bufio.Scanner
Expand Down
30 changes: 1 addition & 29 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package sse

import (
"bytes"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -84,34 +83,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
continue
}

if len(ev.Data) > 0 {
fmt.Fprintf(w, "id: %s\n", ev.ID)

if s.SplitData {
sd := bytes.Split(ev.Data, []byte("\n"))
for i := range sd {
fmt.Fprintf(w, "data: %s\n", sd[i])
}
} else {
if bytes.HasPrefix(ev.Data, []byte(":")) {
fmt.Fprintf(w, "%s\n", ev.Data)
} else {
fmt.Fprintf(w, "data: %s\n", ev.Data)
}
}

if len(ev.Event) > 0 {
fmt.Fprintf(w, "event: %s\n", ev.Event)
}

if len(ev.Retry) > 0 {
fmt.Fprintf(w, "retry: %s\n", ev.Retry)
}
}

if len(ev.Comment) > 0 {
fmt.Fprintf(w, ": %s\n", ev.Comment)
}
ev.Write(w, EventWriteConfig{SplitData: s.SplitData})

fmt.Fprint(w, "\n")

Expand Down