diff --git a/internal/broker/conn.go b/internal/broker/conn.go index 3e1cc91d..30a7fbfb 100644 --- a/internal/broker/conn.go +++ b/internal/broker/conn.go @@ -44,6 +44,7 @@ type Conn struct { service *Service // The service for this connection. subs *message.Counters // The subscriptions for this connection. measurer stats.Measurer // The measurer to use for monitoring. + links map[string]string // The map of all pre-authorized links. } // NewConn creates a new connection. @@ -55,6 +56,7 @@ func (s *Service) newConn(t net.Conn) *Conn { socket: t, subs: message.NewCounters(), measurer: s.measurer, + links: map[string]string{}, } // Generate a globally unique id as well @@ -137,11 +139,13 @@ func (c *Conn) onReceive(msg mqtt.Message) error { // We got an attempt to connect to MQTT. case mqtt.TypeOfConnect: - packet := msg.(*mqtt.Connect) - c.username = string(packet.Username) + var result uint8 + if !c.onConnect(msg.(*mqtt.Connect)) { + result = 0x05 // Unauthorized + } // Write the ack - ack := mqtt.Connack{ReturnCode: 0x00} + ack := mqtt.Connack{ReturnCode: result} if _, err := ack.EncodeTo(c.socket); err != nil { return err } diff --git a/internal/broker/handlers.go b/internal/broker/handlers.go index 1e39feec..098c14b3 100644 --- a/internal/broker/handlers.go +++ b/internal/broker/handlers.go @@ -16,61 +16,79 @@ package broker import ( "encoding/json" + "fmt" + "regexp" "strings" "time" "github.com/emitter-io/emitter/internal/message" + "github.com/emitter-io/emitter/internal/network/mqtt" + "github.com/emitter-io/emitter/internal/provider/contract" "github.com/emitter-io/emitter/internal/provider/logging" "github.com/emitter-io/emitter/internal/security" + "github.com/emitter-io/emitter/internal/security/hash" "github.com/kelindar/binary" ) const ( requestKeygen = 548658350 // hash("keygen") requestPresence = 3869262148 // hash("presence") + requestLink = 2667034312 // hash("link") requestMe = 2539734036 // hash("me") ) -// ------------------------------------------------------------------------------------ +var ( + shortcut = regexp.MustCompile("^[a-zA-Z0-9]{1,2}$") +) -// OnSubscribe is a handler for MQTT Subscribe events. -func (c *Conn) onSubscribe(mqttTopic []byte) *Error { +// ------------------------------------------------------------------------------------ - // Parse the channel - channel := security.ParseChannel(mqttTopic) - if channel.ChannelType == security.ChannelInvalid { - return ErrBadRequest - } +// Authorize attempts to authorize a channel with its key +func (c *Conn) authorize(channel *security.Channel, permission uint8) (contract.Contract, security.Key, bool) { // Attempt to parse the key key, err := c.service.Cipher.DecryptKey(channel.Key) if err != nil || key.IsExpired() { - return ErrUnauthorized + return nil, nil, false } // Attempt to fetch the contract using the key. Underneath, it's cached. contract, contractFound := c.service.contracts.Get(key.Contract()) - if !contractFound { - return ErrNotFound + if !contractFound || !contract.Validate(key) || !key.HasPermission(permission) || !key.ValidateChannel(channel) { + return nil, nil, false } - // Validate the contract - if !contract.Validate(key) { - return ErrUnauthorized - } + // Return the contract and the key + return contract, key, true +} - // Check if the key has the permission to read from here - if !key.HasPermission(security.AllowRead) { - return ErrUnauthorized +// ------------------------------------------------------------------------------------ + +// onConnect handles the connection authorization +func (c *Conn) onConnect(packet *mqtt.Connect) bool { + c.username = string(packet.Username) + return true +} + +// ------------------------------------------------------------------------------------ + +// OnSubscribe is a handler for MQTT Subscribe events. +func (c *Conn) onSubscribe(mqttTopic []byte) *Error { + + // Parse the channel + channel := security.ParseChannel(mqttTopic) + if channel.ChannelType == security.ChannelInvalid { + return ErrBadRequest } - // Check if the key has the permission for the required channel - if !key.ValidateChannel(channel) { + // Check the authorization and permissions + contract, key, allowed := c.authorize(channel, security.AllowRead) + if !allowed { return ErrUnauthorized } // Subscribe the client to the channel - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) c.Subscribe(ssid, channel.Channel) // In case of ttl, check the key provides the permission to store (soft permission) @@ -105,35 +123,14 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { return ErrBadRequest } - // Attempt to parse the key - key, err := c.service.Cipher.DecryptKey(channel.Key) - if err != nil || key.IsExpired() { - return ErrUnauthorized - } - - // Attempt to fetch the contract using the key. Underneath, it's cached. - contract, contractFound := c.service.contracts.Get(key.Contract()) - if !contractFound { - return ErrNotFound - } - - // Validate the contract - if !contract.Validate(key) { - return ErrUnauthorized - } - - // Check if the key has the permission to read from here - if !key.HasPermission(security.AllowRead) { - return ErrUnauthorized - } - - // Check if the key has the permission for the required channel - if !key.ValidateChannel(channel) { + // Check the authorization and permissions + contract, key, allowed := c.authorize(channel, security.AllowRead) + if !allowed { return ErrUnauthorized } // Unsubscribe the client from the channel - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) c.Unsubscribe(ssid, channel.Channel) c.track(contract) return nil @@ -143,8 +140,13 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { // OnPublish is a handler for MQTT Publish events. func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { + exclude := "" + if len(mqttTopic) <= 2 && c.links != nil { + mqttTopic = []byte(c.links[string(mqttTopic)]) + exclude = c.ID() + } - // Parse the channel + // Make sure we have a valid channel channel := security.ParseChannel(mqttTopic) if channel.ChannelType == security.ChannelInvalid { return ErrBadRequest @@ -161,36 +163,15 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { return nil } - // Attempt to parse the key - key, err := c.service.Cipher.DecryptKey(channel.Key) - if err != nil || key.IsExpired() { - return ErrUnauthorized - } - - // Attempt to fetch the contract using the key. Underneath, it's cached. - contract, contractFound := c.service.contracts.Get(key.Contract()) - if !contractFound { - return ErrNotFound - } - - // Validate the contract - if !contract.Validate(key) { - return ErrUnauthorized - } - - // Check if the key has the permission to write here - if !key.HasPermission(security.AllowWrite) { - return ErrUnauthorized - } - - // Check if the key has the permission for the required channel - if !key.ValidateChannel(channel) { + // Check the authorization and permissions + contract, key, allowed := c.authorize(channel, security.AllowWrite) + if !allowed { return ErrUnauthorized } // Create a new message msg := message.New( - message.NewSsid(key.Contract(), channel), + message.NewSsid(key.Contract(), channel.Query), channel.Channel, payload, ) @@ -202,7 +183,7 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { } // Iterate through all subscribers and send them the message - size := c.service.publish(msg) + size := c.service.publish(msg, exclude) // Write the monitoring information c.track(contract) @@ -241,6 +222,9 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b case requestMe: resp, ok = c.onMe() return + case requestLink: + resp, ok = c.onLink(payload) + return default: return } @@ -248,11 +232,88 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b // ------------------------------------------------------------------------------------ +// onLink handles a request to create a link. +func (c *Conn) onLink(payload []byte) (interface{}, bool) { + var request linkRequest + if err := json.Unmarshal(payload, &request); err != nil { + return ErrBadRequest, false + } + + // Check whether the name is a valid shortcut name + if !shortcut.Match([]byte(request.Name)) { + return ErrLinkInvalid, false + } + + // Make the channel from the request or try to make a private one + channel := security.MakeChannel(request.Key, request.Channel) + if request.Private { + channel = c.makePrivateChannel(request.Key, request.Channel) + } + + // Ensures that the channel requested is valid + if channel == nil || channel.ChannelType == security.ChannelInvalid { + return ErrBadRequest, false + } + + // Create the link with the name and set the full channel to it + c.links[request.Name] = channel.String() + + // If an auto-subscribe was requested and the key has read permissions, subscribe + if _, key, allowed := c.authorize(channel, security.AllowRead); allowed && request.Subscribe { + c.Subscribe(message.NewSsid(key.Contract(), channel.Query), channel.Channel) + } + + return &linkResponse{ + Status: 200, + Name: request.Name, + Channel: channel.SafeString(), + }, true +} + +// makePrivateChannel creates a private channel and an appropriate key. +func (c *Conn) makePrivateChannel(chanKey, chanName string) *security.Channel { + channel := security.MakeChannel(chanKey, chanName) + if channel.ChannelType != security.ChannelStatic { + return nil + } + + // Make sure we can actually extend it + _, key, allowed := c.authorize(channel, security.AllowExtend) + if !allowed { + return nil + } + + // Create a new key for the private link + target := fmt.Sprintf("%s%s/", channel.Channel, c.ID()) + if err := key.SetTarget(target); err != nil { + return nil + } + + // Encrypt the key for storing + encryptedKey, err := c.service.Cipher.EncryptKey(key) + if err != nil { + return nil + } + + // Create the private channel + channel.Channel = []byte(target) + channel.Query = append(channel.Query, hash.Of([]byte(c.ID()))) + channel.Key = []byte(encryptedKey) + return channel +} + +// ------------------------------------------------------------------------------------ + // OnMe is a handler that returns information to the connection. func (c *Conn) onMe() (interface{}, bool) { - // Success, return the response + links := make(map[string]string) + for k, v := range c.links { + links[k] = security.ParseChannel([]byte(v)).SafeString() + } + return &meResponse{ - ID: c.ID(), + ID: c.ID(), + Links: links, }, true } @@ -407,7 +468,7 @@ func (c *Conn) onPresence(payload []byte) (interface{}, bool) { } // Create the ssid for the presence - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) // Check if the client is interested in subscribing/unsubscribing from changes. if msg.Changes { diff --git a/internal/broker/handlers_dto.go b/internal/broker/handlers_dto.go index 913bc068..f50412b0 100644 --- a/internal/broker/handlers_dto.go +++ b/internal/broker/handlers_dto.go @@ -35,15 +35,16 @@ func (e *Error) Error() string { return e.Message } // Represents a set of errors used in the handlers. var ( - ErrBadRequest = &Error{Status: 400, Message: "The request was invalid or cannot be otherwise served."} - ErrUnauthorized = &Error{Status: 401, Message: "The security key provided is not authorized to perform this operation."} - ErrPaymentRequired = &Error{Status: 402, Message: "The request can not be served, as the payment is required to proceed."} - ErrForbidden = &Error{Status: 403, Message: "The request is understood, but it has been refused or access is not allowed."} - ErrNotFound = &Error{Status: 404, Message: "The resource requested does not exist."} - ErrServerError = &Error{Status: 500, Message: "An unexpected condition was encountered and no more specific message is suitable."} - ErrNotImplemented = &Error{Status: 501, Message: "The server either does not recognize the request method, or it lacks the ability to fulfill the request."} - ErrTargetInvalid = &Error{Status: 400, Message: "Channel should end with `/` for strict types or `/#/` for wildcards."} - ErrTargetTooLong = &Error{Status: 400, Message: "Channel can not have more than 23 parts."} + ErrBadRequest = &Error{Status: 400, Message: "the request was invalid or cannot be otherwise served"} + ErrUnauthorized = &Error{Status: 401, Message: "the security key provided is not authorized to perform this operation"} + ErrPaymentRequired = &Error{Status: 402, Message: "the request can not be served, as the payment is required to proceed"} + ErrForbidden = &Error{Status: 403, Message: "the request is understood, but it has been refused or access is not allowed"} + ErrNotFound = &Error{Status: 404, Message: "the resource requested does not exist"} + ErrServerError = &Error{Status: 500, Message: "an unexpected condition was encountered and no more specific message is suitable"} + ErrNotImplemented = &Error{Status: 501, Message: "the server either does not recognize the request method, or it lacks the ability to fulfill the request"} + ErrTargetInvalid = &Error{Status: 400, Message: "channel should end with `/` for strict types or `/#/` for wildcards"} + ErrTargetTooLong = &Error{Status: 400, Message: "channel can not have more than 23 parts."} + ErrLinkInvalid = &Error{Status: 400, Message: "the link must be an alphanumeric string of 1 or 2 characters"} ) // ------------------------------------------------------------------------------------ @@ -63,7 +64,7 @@ func (m *keyGenRequest) expires() time.Time { return time.Now().Add(time.Duration(m.TTL) * time.Second).UTC() } -func (m *keyGenRequest) access() uint32 { +func (m *keyGenRequest) access() uint8 { required := security.AllowNone for i := 0; i < len(m.Type); i++ { @@ -78,6 +79,10 @@ func (m *keyGenRequest) access() uint32 { required |= security.AllowLoad case 'p': required |= security.AllowPresence + case 'e': + required |= security.AllowExtend + case 'x': + required |= security.AllowExecute } } @@ -86,12 +91,6 @@ func (m *keyGenRequest) access() uint32 { // ------------------------------------------------------------------------------------ -type meResponse struct { - ID string `json:"id"` -} - -// ------------------------------------------------------------------------------------ - type keyGenResponse struct { Status int `json:"status"` Key string `json:"key"` @@ -100,6 +99,31 @@ type keyGenResponse struct { // ------------------------------------------------------------------------------------ +type linkRequest struct { + Name string `json:"name"` // The name of the shortcut, max 2 characters. + Key string `json:"key"` // The key for the channel. + Channel string `json:"channel"` // The channel name for the shortcut. + Subscribe bool `json:"subscribe"` // Specifies whether the broker should auto-subscribe. + Private bool `json:"private"` // Specifies whether the broker should generate a private link. +} + +// ------------------------------------------------------------------------------------ + +type linkResponse struct { + Status int `json:"status"` // The status of the response. + Name string `json:"name,omitempty"` // The name of the shortcut, max 2 characters. + Channel string `json:"channel,omitempty"` // The channel which was registered. +} + +// ------------------------------------------------------------------------------------ + +type meResponse struct { + ID string `json:"id"` // The private ID of the connection. + Links map[string]string `json:"links,omitempty"` // The set of pre-defined channels. +} + +// ------------------------------------------------------------------------------------ + type presenceRequest struct { Key string `json:"key"` // The channel key for this request. Channel string `json:"channel"` // The target channel for this request. diff --git a/internal/broker/handlers_test.go b/internal/broker/handlers_test.go index 30bdb274..0f36eea4 100644 --- a/internal/broker/handlers_test.go +++ b/internal/broker/handlers_test.go @@ -15,6 +15,56 @@ import ( "github.com/stretchr/testify/mock" ) +func TestHandlers_onLink(t *testing.T) { + tests := []struct { + packet string + channel string + success bool + }{ + { + packet: `{ "name": "AB", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`, + channel: "a/b/c/", + success: true, + }, + { + packet: `{ "name": "AB", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/"}`, + channel: "a/b/c/", + success: true, + }, + {packet: `{ "name": "ABC", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + {packet: `{ "name": "", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + {packet: `{"key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + {packet: `{ "name": "AB", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "---", "private": true, "subscribe": true }`}, + {packet: `{ "name": "AB", "key": "xxx", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + } + + for _, tc := range tests { + t.Run(tc.packet, func(*testing.T) { + provider := secmock.NewContractProvider() + contract := new(secmock.Contract) + contract.On("Validate", mock.Anything).Return(true) + provider.On("Get", mock.Anything).Return(contract, true) + license, _ := security.ParseLicense(testLicense) + s := &Service{ + contracts: provider, + subscriptions: message.NewTrie(), + License: license, + presence: make(chan *presenceNotify, 100), + } + + s.Cipher, _ = s.License.Cipher() + conn := netmock.NewConn() + nc := s.newConn(conn.Client) + + resp, ok := nc.onLink([]byte(tc.packet)) + assert.Equal(t, tc.success, ok) + if tc.success { + assert.Contains(t, resp.(*linkResponse).Channel, tc.channel) + } + }) + } +} + func TestHandlers_onMe(t *testing.T) { license, _ := security.ParseLicense(testLicense) s := &Service{ @@ -24,10 +74,12 @@ func TestHandlers_onMe(t *testing.T) { conn := netmock.NewConn() nc := s.newConn(conn.Client) + nc.links["0"] = "key/a/b/c/" resp, success := nc.onMe() meResp := resp.(*meResponse) - assert.Equal(t, success, true, success) + assert.True(t, success) + assert.Equal(t, "a/b/c/", meResp.Links["0"]) assert.NotNil(t, resp) assert.NotZero(t, len(meResp.ID)) } @@ -86,9 +138,9 @@ func TestHandlers_onSubscribeUnsubscribe(t *testing.T) { { channel: "0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", subCount: 0, - subErr: ErrNotFound, + subErr: ErrUnauthorized, unsubCount: 0, - unsubErr: ErrNotFound, + unsubErr: ErrUnauthorized, contractValid: true, contractFound: false, msg: "Contract not found case", @@ -150,7 +202,7 @@ func TestHandlers_onSubscribeUnsubscribe(t *testing.T) { // Search for the ssid. channel := security.ParseChannel([]byte(tc.channel)) key, _ := s.Cipher.DecryptKey(channel.Key) - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) subscribers := s.subscriptions.Lookup(ssid) assert.Equal(t, tc.subCount, len(subscribers)) @@ -210,7 +262,7 @@ func TestHandlers_onPublish(t *testing.T) { { channel: "0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", payload: "test", - err: ErrNotFound, + err: ErrUnauthorized, contractValid: true, contractFound: false, msg: "Contract not found case", @@ -498,6 +550,11 @@ func TestHandlers_onEmitterRequest(t *testing.T) { query: []uint32{requestMe}, success: true, }, + { + channel: "link", + query: []uint32{requestLink}, + success: false, + }, } for _, tc := range tests { diff --git a/internal/broker/query.go b/internal/broker/query.go index bdd89bbf..c7f0c27b 100644 --- a/internal/broker/query.go +++ b/internal/broker/query.go @@ -164,7 +164,7 @@ func (c *QueryManager) Query(query string, payload []byte) (message.Awaiter, err message.Ssid{idSystem, idQuery, awaiter.id}, []byte(channel), payload, - )) + ), "") return awaiter, nil } diff --git a/internal/broker/service.go b/internal/broker/service.go index 57cc2804..3306f63f 100644 --- a/internal/broker/service.go +++ b/internal/broker/service.go @@ -250,7 +250,7 @@ func (s *Service) notifyPresenceChange() { return case notif := <-s.presence: if encoded, ok := notif.Encode(); ok { - s.publish(message.New(notif.Ssid, channel, encoded)) + s.publish(message.New(notif.Ssid, channel, encoded), "") } } } @@ -364,7 +364,7 @@ func (s *Service) onHTTPPresence(w http.ResponseWriter, r *http.Request) { } // Create the ssid for the presence - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) now := time.Now().UTC().Unix() who := getAllPresence(s, ssid) resp, err := json.Marshal(&presenceResponse{ @@ -436,10 +436,12 @@ func (s *Service) Survey(query string, payload []byte) (message.Awaiter, error) } // Publish publishes a message to everyone and returns the number of outgoing bytes written. -func (s *Service) publish(m *message.Message) (n int64) { +func (s *Service) publish(m *message.Message, exclude string) (n int64) { size := m.Size() for _, subscriber := range s.subscriptions.Lookup(m.Ssid()) { - subscriber.Send(m) + if subscriber.ID() != exclude { + subscriber.Send(m) + } // Increment the egress size only for direct subscribers if subscriber.Type() == message.SubscriberDirect { @@ -455,10 +457,10 @@ func (s *Service) selfPublish(channelName string, payload []byte) { channel := security.ParseChannel([]byte("emitter/" + channelName)) if channel.ChannelType == security.ChannelStatic { s.publish(message.New( - message.NewSsid(s.License.Contract, channel), + message.NewSsid(s.License.Contract, channel.Query), channel.Channel, payload, - )) + ), "") } } diff --git a/internal/broker/service_test.go b/internal/broker/service_test.go index a19ada41..98a65510 100644 --- a/internal/broker/service_test.go +++ b/internal/broker/service_test.go @@ -233,6 +233,32 @@ func TestPubsub(t *testing.T) { assert.Equal(t, mqtt.TypeOfUnsuback, pkt.Type()) } + { // Create a private link + msg := mqtt.Publish{ + Header: &mqtt.StaticHeader{QOS: 0}, + Topic: []byte("emitter/link/"), + Payload: []byte(`{ "name": "hi", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true }`), + } + _, err := msg.EncodeTo(cli) + assert.NoError(t, err) + } + + { // Read the link response + pkt, err := mqtt.DecodePacket(cli) + assert.NoError(t, err) + assert.Equal(t, mqtt.TypeOfPublish, pkt.Type()) + } + + { // Publish a message to a link + msg := mqtt.Publish{ + Header: &mqtt.StaticHeader{QOS: 0}, + Topic: []byte("hi"), + Payload: []byte("hello world"), + } + _, err := msg.EncodeTo(cli) + assert.NoError(t, err) + } + { // Disconnect from the broker disconnect := mqtt.Disconnect{} n, err := disconnect.EncodeTo(cli) diff --git a/internal/message/sub.go b/internal/message/sub.go index ea0c878f..ab9b47c5 100644 --- a/internal/message/sub.go +++ b/internal/message/sub.go @@ -20,8 +20,6 @@ import ( "sync" "time" "unsafe" - - "github.com/emitter-io/emitter/internal/security" ) // Various constant parts of the SSID. @@ -40,10 +38,10 @@ var Query = Ssid{system, query} type Ssid []uint32 // NewSsid creates a new SSID. -func NewSsid(contract uint32, c *security.Channel) Ssid { - ssid := make([]uint32, 0, len(c.Query)+1) +func NewSsid(contract uint32, query []uint32) Ssid { + ssid := make([]uint32, 0, len(query)+1) ssid = append(ssid, uint32(contract)) - ssid = append(ssid, c.Query...) + ssid = append(ssid, query...) return ssid } diff --git a/internal/message/sub_test.go b/internal/message/sub_test.go index beeecc39..77ce8735 100644 --- a/internal/message/sub_test.go +++ b/internal/message/sub_test.go @@ -8,7 +8,7 @@ import ( ) func BenchmarkSsidEncode(b *testing.B) { - ssid := NewSsid(0, &security.Channel{Query: []uint32{56498455, 44565213, 46512350, 18204498}}) + ssid := NewSsid(0, []uint32{56498455, 44565213, 46512350, 18204498}) b.ReportAllocs() b.ResetTimer() @@ -32,7 +32,7 @@ func TestSsid(t *testing.T) { ChannelType: security.ChannelStatic, } - ssid := NewSsid(0, &c) + ssid := NewSsid(0, c.Query) assert.Equal(t, uint32(0), ssid.Contract()) assert.Equal(t, uint32(0x2c), ssid.GetHashCode()) } @@ -53,7 +53,7 @@ func TestSsidEncode(t *testing.T) { } for _, tc := range tests { - ssid := NewSsid(0, &security.Channel{Query: tc.ssid}) + ssid := NewSsid(0, tc.ssid) assert.Equal(t, tc.expected, ssid.Encode()) } } diff --git a/internal/security/channel.go b/internal/security/channel.go index 8f866f7b..4295b7d2 100644 --- a/internal/security/channel.go +++ b/internal/security/channel.go @@ -15,6 +15,7 @@ package security import ( + "fmt" "strconv" "time" "unsafe" @@ -75,6 +76,32 @@ func (c *Channel) Window() (time.Time, time.Time) { return toUnix(u0), toUnix(u1) } +// SafeString returns a string representation of the channel without the key. +func (c *Channel) SafeString() string { + text := string(c.Channel) + if len(c.Options) == 0 { + return text + } + + text += "?" + for i, v := range c.Options { + if i > 0 { + text += "&" + } + + text += v.Key + "=" + v.Value + } + return text +} + +// String returns a string representation of the channel. +func (c *Channel) String() string { + text := string(c.Key) + text += "/" + text += c.SafeString() + return text +} + // Converts the time to Unix Time with validation. func toUnix(t int64) time.Time { if t == 0 || t < MinTime || t > MaxTime { @@ -97,6 +124,11 @@ func (c *Channel) getOption(name string) (int64, bool) { return 0, false } +// MakeChannel attempts to parse the channel from the key and channel strings. +func MakeChannel(key, channel string) *Channel { + return ParseChannel([]byte(fmt.Sprintf("%s/%s", key, channel))) +} + // ParseChannel attempts to parse the channel from the underlying slice. func ParseChannel(text []byte) (channel *Channel) { channel = new(Channel) diff --git a/internal/security/channel_test.go b/internal/security/channel_test.go index f278865e..715ba1e7 100644 --- a/internal/security/channel_test.go +++ b/internal/security/channel_test.go @@ -195,3 +195,34 @@ func TestGetChannelTarget(t *testing.T) { assert.Equal(t, tc.target, target) } } + +func TestMakeChannel(t *testing.T) { + tests := []struct { + key string + channel string + }{ + {key: "key1", channel: "emitter/a/"}, + } + + for _, tc := range tests { + channel := MakeChannel(tc.key, tc.channel) + assert.Equal(t, tc.key, string(channel.Key)) + assert.Equal(t, tc.channel, string(channel.Channel)) + } +} + +func TestChannelString(t *testing.T) { + tests := []struct { + channel string + }{ + {channel: "emitter/a/?last=42&abc=9"}, + {channel: "emitter/a/?last=1200"}, + {channel: "emitter/a/?last=1200a"}, + {channel: "emitter/a/"}, + } + + for _, tc := range tests { + channel := ParseChannel([]byte(tc.channel)) + assert.Equal(t, tc.channel, channel.String()) + } +} diff --git a/internal/security/crypto.go b/internal/security/crypto.go index 7d017d39..5f86be3a 100644 --- a/internal/security/crypto.go +++ b/internal/security/crypto.go @@ -127,7 +127,7 @@ func (c *Cipher) EncryptKey(k Key) (string, error) { } // GenerateKey generates a new key. -func (c *Cipher) GenerateKey(masterKey Key, channel string, permissions uint32, expires time.Time, maxRandSalt int16) (string, error) { +func (c *Cipher) GenerateKey(masterKey Key, channel string, permissions uint8, expires time.Time, maxRandSalt int16) (string, error) { if maxRandSalt <= 0 { maxRandSalt = math.MaxInt16 } diff --git a/internal/security/hash/murmur_test.go b/internal/security/hash/murmur_test.go index 447727e8..6beaedb5 100644 --- a/internal/security/hash/murmur_test.go +++ b/internal/security/hash/murmur_test.go @@ -36,6 +36,11 @@ func TestMeHash(t *testing.T) { assert.Equal(t, uint32(2539734036), h) } +func TestLinkHash(t *testing.T) { + h := Of([]byte("link")) + assert.Equal(t, uint32(2667034312), h) +} + func TestGetHash(t *testing.T) { h := Of([]byte("+")) if h != 1815237614 { diff --git a/internal/security/key.go b/internal/security/key.go index a4046cf2..d416e361 100644 --- a/internal/security/key.go +++ b/internal/security/key.go @@ -24,13 +24,15 @@ import ( // Access types for a security key. const ( - AllowNone = uint32(0) // Key has no privileges. - AllowMaster = uint32(1 << 0) // Key should be allowed to generate other keys. - AllowRead = uint32(1 << 1) // Key should be allowed to subscribe to the target channel. - AllowWrite = uint32(1 << 2) // Key should be allowed to publish to the target channel. - AllowStore = uint32(1 << 3) // Key should be allowed to write to the message history of the target channel. - AllowLoad = uint32(1 << 4) // Key should be allowed to write to read the message history of the target channel. - AllowPresence = uint32(1 << 5) // Key should be allowed to query the presence on the target channel. + AllowNone = uint8(0) // Key has no privileges. + AllowMaster = uint8(1 << 0) // Key should be allowed to generate other keys. + AllowRead = uint8(1 << 1) // Key should be allowed to subscribe to the target channel. + AllowWrite = uint8(1 << 2) // Key should be allowed to publish to the target channel. + AllowStore = uint8(1 << 3) // Key should be allowed to write to the message history of the target channel. + AllowLoad = uint8(1 << 4) // Key should be allowed to write to read the message history of the target channel. + AllowPresence = uint8(1 << 5) // Key should be allowed to query the presence on the target channel. + AllowExtend = uint8(1 << 6) // Key should be allowed to create sub-channels by extending an existing one. + AllowExecute = uint8(1 << 7) // Key should be allowed to execute code. (RESERVED) AllowReadWrite = AllowRead | AllowWrite // Key should be allowed to read and write to the target channel. AllowStoreLoad = AllowStore | AllowLoad // Key should be allowed to read and write the message history. ) @@ -98,13 +100,13 @@ func (k Key) SetSignature(value uint32) { } // Permissions gets the permission flags. -func (k Key) Permissions() uint32 { - return uint32(k[15]) +func (k Key) Permissions() uint8 { + return k[15] } // SetPermissions sets the permission flags. -func (k Key) SetPermissions(value uint32) { - k[15] = byte(value) +func (k Key) SetPermissions(value uint8) { + k[15] = value } // ValidateChannel validates the channel string. @@ -251,7 +253,7 @@ func (k Key) IsMaster() bool { } // HasPermission check whether the key provides some permission. -func (k Key) HasPermission(flag uint32) bool { +func (k Key) HasPermission(flag uint8) bool { p := k.Permissions() return (p & flag) == flag }