From ce9a28df5703f5c30b9ad7592841c78ff43190e1 Mon Sep 17 00:00:00 2001 From: Roman Date: Sat, 26 Jan 2019 23:10:54 +0800 Subject: [PATCH 1/6] simplified client-server setup --- internal/broker/conn.go | 9 +- internal/broker/handlers.go | 170 +++++++++++++++----------- internal/broker/handlers_dto.go | 3 +- internal/broker/handlers_test.go | 58 ++++++++- internal/broker/query.go | 2 +- internal/broker/service.go | 10 +- internal/security/hash/murmur_test.go | 5 + internal/security/key.go | 4 +- internal/security/password.go | 45 +++++++ internal/security/password_test.go | 41 +++++++ 10 files changed, 262 insertions(+), 85 deletions(-) create mode 100644 internal/security/password.go create mode 100644 internal/security/password_test.go diff --git a/internal/broker/conn.go b/internal/broker/conn.go index 3e1cc91d..7e80f578 100644 --- a/internal/broker/conn.go +++ b/internal/broker/conn.go @@ -41,6 +41,7 @@ type Conn struct { username string // The username provided by the client during MQTT connect. luid security.ID // The locally unique id of the connection. guid string // The globally unique id of the connection. + dial string // The pre-authorized channel with a valid key used for dial. service *Service // The service for this connection. subs *message.Counters // The subscriptions for this connection. measurer stats.Measurer // The measurer to use for monitoring. @@ -137,11 +138,13 @@ func (c *Conn) onReceive(msg mqtt.Message) error { // We got an attempt to connect to MQTT. case mqtt.TypeOfConnect: - packet := msg.(*mqtt.Connect) - c.username = string(packet.Username) + var result uint8 + if !c.onConnect(msg.(*mqtt.Connect)) { + result = 0x05 // Unauthorized + } // Write the ack - ack := mqtt.Connack{ReturnCode: 0x00} + ack := mqtt.Connack{ReturnCode: result} if _, err := ack.EncodeTo(c.socket); err != nil { return err } diff --git a/internal/broker/handlers.go b/internal/broker/handlers.go index 1e39feec..26b97293 100644 --- a/internal/broker/handlers.go +++ b/internal/broker/handlers.go @@ -16,10 +16,14 @@ package broker import ( "encoding/json" + "fmt" + "github.com/emitter-io/emitter/internal/security/hash" "strings" "time" "github.com/emitter-io/emitter/internal/message" + "github.com/emitter-io/emitter/internal/network/mqtt" + "github.com/emitter-io/emitter/internal/provider/contract" "github.com/emitter-io/emitter/internal/provider/logging" "github.com/emitter-io/emitter/internal/security" "github.com/kelindar/binary" @@ -33,39 +37,100 @@ const ( // ------------------------------------------------------------------------------------ -// OnSubscribe is a handler for MQTT Subscribe events. -func (c *Conn) onSubscribe(mqttTopic []byte) *Error { - - // Parse the channel - channel := security.ParseChannel(mqttTopic) - if channel.ChannelType == security.ChannelInvalid { - return ErrBadRequest - } +// Authorize attempts to authorize a channel with its key +func (c *Conn) authorize(channel *security.Channel, permission uint32) (contract.Contract, security.Key, bool) { // Attempt to parse the key key, err := c.service.Cipher.DecryptKey(channel.Key) if err != nil || key.IsExpired() { - return ErrUnauthorized + return nil, nil, false } // Attempt to fetch the contract using the key. Underneath, it's cached. contract, contractFound := c.service.contracts.Get(key.Contract()) - if !contractFound { - return ErrNotFound + if !contractFound || !contract.Validate(key) || !key.HasPermission(permission) || !key.ValidateChannel(channel) { + return nil, nil, false } - // Validate the contract - if !contract.Validate(key) { - return ErrUnauthorized + // Return the contract and the key + return contract, key, true +} + +// ------------------------------------------------------------------------------------ + +// onConnect handles the connection authorization +func (c *Conn) onConnect(packet *mqtt.Connect) bool { + c.username = string(packet.Username) + + // If there's no password provided, we're done + if len(packet.Password) == 0 { + return true } - // Check if the key has the permission to read from here - if !key.HasPermission(security.AllowRead) { - return ErrUnauthorized + // If the password was provided, evalute our scheme. The format should be: + // dial://{key}/{channel}/ + scheme, channel := security.ParsePassword(string(packet.Password)) + if len(scheme) == 0 || channel == nil || channel.ChannelType == security.ChannelInvalid { + return false + } + + // If it's a dial, we need to append the connection ID and generate a new key + if scheme == "dial" { + if dial, ok := c.dialAndSubscribe(channel); ok { + c.dial = dial // Keep it as string, we want to copy later + return true + } + } + + return false +} + +// creates a new channel for the dial and subscribes to it if allowed +func (c *Conn) dialAndSubscribe(channel *security.Channel) (string, bool) { + + // Check the authorization and permissions + _, key, allowed := c.authorize(channel, security.AllowAny) + if !allowed { + return "", false + } + + // Create a new key for the dial + target := fmt.Sprintf("%s%s/", channel.Channel, c.ID()) + if err := key.SetTarget(target); err != nil { + return "", false + } + + // Encrypt the key for storing + encryptedKey, err := c.service.Cipher.EncryptKey(key) + if err != nil { + return "", false + } + + // Auto-subscribe to the dial if allowed + if key.HasPermission(security.AllowRead) { + channel.Channel = []byte(target) + channel.Query = append(channel.Query, hash.Of([]byte(c.ID()))) + ssid := message.NewSsid(key.Contract(), channel) + c.Subscribe(ssid, channel.Channel) + } + + return fmt.Sprintf("%s/%s", encryptedKey, target), true +} + +// ------------------------------------------------------------------------------------ + +// OnSubscribe is a handler for MQTT Subscribe events. +func (c *Conn) onSubscribe(mqttTopic []byte) *Error { + + // Parse the channel + channel := security.ParseChannel(mqttTopic) + if channel.ChannelType == security.ChannelInvalid { + return ErrBadRequest } - // Check if the key has the permission for the required channel - if !key.ValidateChannel(channel) { + // Check the authorization and permissions + contract, key, allowed := c.authorize(channel, security.AllowRead) + if !allowed { return ErrUnauthorized } @@ -105,30 +170,9 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { return ErrBadRequest } - // Attempt to parse the key - key, err := c.service.Cipher.DecryptKey(channel.Key) - if err != nil || key.IsExpired() { - return ErrUnauthorized - } - - // Attempt to fetch the contract using the key. Underneath, it's cached. - contract, contractFound := c.service.contracts.Get(key.Contract()) - if !contractFound { - return ErrNotFound - } - - // Validate the contract - if !contract.Validate(key) { - return ErrUnauthorized - } - - // Check if the key has the permission to read from here - if !key.HasPermission(security.AllowRead) { - return ErrUnauthorized - } - - // Check if the key has the permission for the required channel - if !key.ValidateChannel(channel) { + // Check the authorization and permissions + contract, key, allowed := c.authorize(channel, security.AllowRead) + if !allowed { return ErrUnauthorized } @@ -143,8 +187,13 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { // OnPublish is a handler for MQTT Publish events. func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { + exclude := "" + if len(mqttTopic) <= 1 { + mqttTopic = []byte(c.dial) + exclude = c.ID() + } - // Parse the channel + // Make sure we have a valid channel channel := security.ParseChannel(mqttTopic) if channel.ChannelType == security.ChannelInvalid { return ErrBadRequest @@ -161,30 +210,9 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { return nil } - // Attempt to parse the key - key, err := c.service.Cipher.DecryptKey(channel.Key) - if err != nil || key.IsExpired() { - return ErrUnauthorized - } - - // Attempt to fetch the contract using the key. Underneath, it's cached. - contract, contractFound := c.service.contracts.Get(key.Contract()) - if !contractFound { - return ErrNotFound - } - - // Validate the contract - if !contract.Validate(key) { - return ErrUnauthorized - } - - // Check if the key has the permission to write here - if !key.HasPermission(security.AllowWrite) { - return ErrUnauthorized - } - - // Check if the key has the permission for the required channel - if !key.ValidateChannel(channel) { + // Check the authorization and permissions + contract, key, allowed := c.authorize(channel, security.AllowWrite) + if !allowed { return ErrUnauthorized } @@ -202,7 +230,7 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { } // Iterate through all subscribers and send them the message - size := c.service.publish(msg) + size := c.service.publish(msg, exclude) // Write the monitoring information c.track(contract) @@ -250,9 +278,9 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b // OnMe is a handler that returns information to the connection. func (c *Conn) onMe() (interface{}, bool) { - // Success, return the response return &meResponse{ - ID: c.ID(), + ID: c.ID(), + Dial: string(c.dial), }, true } diff --git a/internal/broker/handlers_dto.go b/internal/broker/handlers_dto.go index 913bc068..0095f5e1 100644 --- a/internal/broker/handlers_dto.go +++ b/internal/broker/handlers_dto.go @@ -87,7 +87,8 @@ func (m *keyGenRequest) access() uint32 { // ------------------------------------------------------------------------------------ type meResponse struct { - ID string `json:"id"` + ID string `json:"id"` // The private ID of the connection, ret + Dial string `json:"dial,omitempty"` // The dial channel name } // ------------------------------------------------------------------------------------ diff --git a/internal/broker/handlers_test.go b/internal/broker/handlers_test.go index 30bdb274..65f016db 100644 --- a/internal/broker/handlers_test.go +++ b/internal/broker/handlers_test.go @@ -1,6 +1,7 @@ package broker import ( + "github.com/emitter-io/emitter/internal/network/mqtt" "testing" "github.com/emitter-io/emitter/internal/message" @@ -24,14 +25,63 @@ func TestHandlers_onMe(t *testing.T) { conn := netmock.NewConn() nc := s.newConn(conn.Client) + nc.dial = "a/b/c/" resp, success := nc.onMe() meResp := resp.(*meResponse) - assert.Equal(t, success, true, success) + assert.True(t, success) + assert.Equal(t, "a/b/c/", meResp.Dial) assert.NotNil(t, resp) assert.NotZero(t, len(meResp.ID)) } +func TestHandlers_onConnect(t *testing.T) { + license, _ := security.ParseLicense(testLicense) + tests := []struct { + password string + channel string + ok bool + }{ + {password: "", ok: true}, + {password: "dial://0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", channel: "a/b/c/CONNECTION_ID/", ok: true}, + + {password: "dial://a/b/c/"}, + {password: "a/b/c/"}, + {password: "agsew350290"}, + {password: "1.2342/24/225"}, + {password: "fake://0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/"}, + } + + for _, tc := range tests { + t.Run(tc.password, func(*testing.T) { + provider := secmock.NewContractProvider() + contract := new(secmock.Contract) + contract.On("Validate", mock.Anything).Return(true) + contract.On("Stats").Return(usage.NewMeter(0)) + provider.On("Get", mock.Anything).Return(contract, true) + s := &Service{ + contracts: provider, + subscriptions: message.NewTrie(), + License: license, + presence: make(chan *presenceNotify, 100), + } + + conn := netmock.NewConn() + nc := s.newConn(conn.Client) + nc.guid = "CONNECTION_ID" + s.Cipher, _ = s.License.Cipher() + ok := nc.onConnect(&mqtt.Connect{ + Password: []byte(tc.password), + }) + + assert.Equal(t, tc.ok, ok, tc.password) + if tc.channel != "" { + assert.Contains(t, string(nc.dial), tc.channel) + } + }) + } +} + func TestHandlers_onSubscribeUnsubscribe(t *testing.T) { license, _ := security.ParseLicense(testLicense) tests := []struct { @@ -86,9 +136,9 @@ func TestHandlers_onSubscribeUnsubscribe(t *testing.T) { { channel: "0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", subCount: 0, - subErr: ErrNotFound, + subErr: ErrUnauthorized, unsubCount: 0, - unsubErr: ErrNotFound, + unsubErr: ErrUnauthorized, contractValid: true, contractFound: false, msg: "Contract not found case", @@ -210,7 +260,7 @@ func TestHandlers_onPublish(t *testing.T) { { channel: "0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", payload: "test", - err: ErrNotFound, + err: ErrUnauthorized, contractValid: true, contractFound: false, msg: "Contract not found case", diff --git a/internal/broker/query.go b/internal/broker/query.go index bdd89bbf..c7f0c27b 100644 --- a/internal/broker/query.go +++ b/internal/broker/query.go @@ -164,7 +164,7 @@ func (c *QueryManager) Query(query string, payload []byte) (message.Awaiter, err message.Ssid{idSystem, idQuery, awaiter.id}, []byte(channel), payload, - )) + ), "") return awaiter, nil } diff --git a/internal/broker/service.go b/internal/broker/service.go index 57cc2804..ba0ea085 100644 --- a/internal/broker/service.go +++ b/internal/broker/service.go @@ -250,7 +250,7 @@ func (s *Service) notifyPresenceChange() { return case notif := <-s.presence: if encoded, ok := notif.Encode(); ok { - s.publish(message.New(notif.Ssid, channel, encoded)) + s.publish(message.New(notif.Ssid, channel, encoded), "") } } } @@ -436,10 +436,12 @@ func (s *Service) Survey(query string, payload []byte) (message.Awaiter, error) } // Publish publishes a message to everyone and returns the number of outgoing bytes written. -func (s *Service) publish(m *message.Message) (n int64) { +func (s *Service) publish(m *message.Message, exclude string) (n int64) { size := m.Size() for _, subscriber := range s.subscriptions.Lookup(m.Ssid()) { - subscriber.Send(m) + if subscriber.ID() != exclude { + subscriber.Send(m) + } // Increment the egress size only for direct subscribers if subscriber.Type() == message.SubscriberDirect { @@ -458,7 +460,7 @@ func (s *Service) selfPublish(channelName string, payload []byte) { message.NewSsid(s.License.Contract, channel), channel.Channel, payload, - )) + ), "") } } diff --git a/internal/security/hash/murmur_test.go b/internal/security/hash/murmur_test.go index 447727e8..c8d7685a 100644 --- a/internal/security/hash/murmur_test.go +++ b/internal/security/hash/murmur_test.go @@ -36,6 +36,11 @@ func TestMeHash(t *testing.T) { assert.Equal(t, uint32(2539734036), h) } +func TestDialHash(t *testing.T) { + h := Of([]byte("dial")) + assert.Equal(t, uint32(1673593207), h) +} + func TestGetHash(t *testing.T) { h := Of([]byte("+")) if h != 1815237614 { diff --git a/internal/security/key.go b/internal/security/key.go index a4046cf2..ac14c8dd 100644 --- a/internal/security/key.go +++ b/internal/security/key.go @@ -16,6 +16,7 @@ package security import ( "errors" + "math" "strings" "time" @@ -33,6 +34,7 @@ const ( AllowPresence = uint32(1 << 5) // Key should be allowed to query the presence on the target channel. AllowReadWrite = AllowRead | AllowWrite // Key should be allowed to read and write to the target channel. AllowStoreLoad = AllowStore | AllowLoad // Key should be allowed to read and write the message history. + AllowAny = math.MaxUint32 ) // Key errors @@ -253,5 +255,5 @@ func (k Key) IsMaster() bool { // HasPermission check whether the key provides some permission. func (k Key) HasPermission(flag uint32) bool { p := k.Permissions() - return (p & flag) == flag + return flag == AllowAny || (p&flag) == flag } diff --git a/internal/security/password.go b/internal/security/password.go new file mode 100644 index 00000000..e02ad59a --- /dev/null +++ b/internal/security/password.go @@ -0,0 +1,45 @@ +/********************************************************************************** +* Copyright (c) 2009-2017 Misakai Ltd. +* This program is free software: you can redistribute it and/or modify it under the +* terms of the GNU Affero General Public License as published by the Free Software +* Foundation, either version 3 of the License, or(at your option) any later version. +* +* This program is distributed in the hope that it will be useful, but WITHOUT ANY +* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A +* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. +* +* You should have received a copy of the GNU Affero General Public License along +* with this program. If not, see. +************************************************************************************/ + +package security + +import ( + "regexp" +) + +// This is a strict password format +var passwordFormat = regexp.MustCompile(`^(dial)\:\/\/(.+)$`) + +// ParsePassword parses a pre-authorized channel key +func ParsePassword(password string) (string, *Channel) { + parts := passwordFormat.FindStringSubmatch(password) + if len(parts) != 3 { + return "", nil // Invalid channel + } + + // Get the scheme and channel and make sure they're valid + scheme := parts[1] + channel := ParseChannel([]byte(parts[2])) + if len(scheme) == 0 || channel == nil || channel.ChannelType == ChannelInvalid { + return "", nil + } + + // For dial to work, the channel must be static + if scheme == "dial" && channel.ChannelType == ChannelStatic { + return scheme, channel + } + + // Safe default + return "", nil +} diff --git a/internal/security/password_test.go b/internal/security/password_test.go new file mode 100644 index 00000000..3f7663fe --- /dev/null +++ b/internal/security/password_test.go @@ -0,0 +1,41 @@ +package security + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParsePassword(t *testing.T) { + tests := []struct { + input string + expect string + success bool + }{ + {input: "dial://emitter/a/?ttl=42&abc=9", expect: "a/", success: true}, + {input: "dial://emitter/a/?ttl=1200", expect: "a/", success: true}, + {input: "dial://emitter/a/?ttl=1200a", expect: "a/", success: true}, + {input: "dial://emitter/a/b/c/", expect: "a/b/c/", success: true}, + + {input: "auto:/test"}, + {input: "auto://"}, + {input: "auto://--------2345205982)@#(%&*@#//)//%"}, + {input: "dial://emitter/+/"}, + {input: "emitter/a/?ttl=42&abc=9"}, + {input: "emitter/a/?ttl=1200"}, + {input: "emitter/a/?ttl=1200a"}, + {input: "emitter/a/"}, + {input: "err://emitter/a/?ttl=42&abc=9"}, + {input: "err://emitter/a/?ttl=1200"}, + {input: "err://emitter/a/?ttl=1200a"}, + {input: "err://emitter/a/"}, + } + + for _, tc := range tests { + scheme, channel := ParsePassword(tc.input) + assert.Equal(t, scheme != "", tc.success, tc.input) + if tc.expect != "" { + assert.Equal(t, tc.expect, string(channel.Channel), tc.input) + } + } +} From 1fc6842cefa0c92cfbb48fc68a4d6f8bcfeb7031 Mon Sep 17 00:00:00 2001 From: Roman Date: Sun, 27 Jan 2019 09:55:46 +0800 Subject: [PATCH 2/6] added an explicit dial permission --- internal/broker/handlers.go | 28 ++++++++--------- internal/broker/handlers_dto.go | 6 +++- internal/broker/handlers_test.go | 4 +-- internal/security/crypto.go | 2 +- internal/security/{password.go => dial.go} | 18 +++++------ .../{password_test.go => dial_test.go} | 8 ++--- internal/security/key.go | 30 +++++++++---------- 7 files changed, 48 insertions(+), 48 deletions(-) rename internal/security/{password.go => dial.go} (78%) rename internal/security/{password_test.go => dial_test.go} (85%) diff --git a/internal/broker/handlers.go b/internal/broker/handlers.go index 26b97293..713a8957 100644 --- a/internal/broker/handlers.go +++ b/internal/broker/handlers.go @@ -38,7 +38,7 @@ const ( // ------------------------------------------------------------------------------------ // Authorize attempts to authorize a channel with its key -func (c *Conn) authorize(channel *security.Channel, permission uint32) (contract.Contract, security.Key, bool) { +func (c *Conn) authorize(channel *security.Channel, permission uint8) (contract.Contract, security.Key, bool) { // Attempt to parse the key key, err := c.service.Cipher.DecryptKey(channel.Key) @@ -67,29 +67,25 @@ func (c *Conn) onConnect(packet *mqtt.Connect) bool { return true } - // If the password was provided, evalute our scheme. The format should be: - // dial://{key}/{channel}/ - scheme, channel := security.ParsePassword(string(packet.Password)) - if len(scheme) == 0 || channel == nil || channel.ChannelType == security.ChannelInvalid { - return false - } - - // If it's a dial, we need to append the connection ID and generate a new key - if scheme == "dial" { - if dial, ok := c.dialAndSubscribe(channel); ok { - c.dial = dial // Keep it as string, we want to copy later - return true - } + // If the password was provided, try to parse the 'dial' string. The format + // should be: dial://{key}/{channel}/ + if dial, ok := c.dialAndSubscribe(string(packet.Password)); ok { + c.dial = dial // Keep it as string, we want to copy later + return true } return false } // creates a new channel for the dial and subscribes to it if allowed -func (c *Conn) dialAndSubscribe(channel *security.Channel) (string, bool) { +func (c *Conn) dialAndSubscribe(password string) (string, bool) { + channel, valid := security.ParseDial(string(password)) + if !valid { + return "", false + } // Check the authorization and permissions - _, key, allowed := c.authorize(channel, security.AllowAny) + _, key, allowed := c.authorize(channel, security.AllowDial) if !allowed { return "", false } diff --git a/internal/broker/handlers_dto.go b/internal/broker/handlers_dto.go index 0095f5e1..0c4db5e6 100644 --- a/internal/broker/handlers_dto.go +++ b/internal/broker/handlers_dto.go @@ -63,7 +63,7 @@ func (m *keyGenRequest) expires() time.Time { return time.Now().Add(time.Duration(m.TTL) * time.Second).UTC() } -func (m *keyGenRequest) access() uint32 { +func (m *keyGenRequest) access() uint8 { required := security.AllowNone for i := 0; i < len(m.Type); i++ { @@ -78,6 +78,10 @@ func (m *keyGenRequest) access() uint32 { required |= security.AllowLoad case 'p': required |= security.AllowPresence + case 'd': + required |= security.AllowDial + case 'x': + required |= security.AllowExecute } } diff --git a/internal/broker/handlers_test.go b/internal/broker/handlers_test.go index 65f016db..95034bf4 100644 --- a/internal/broker/handlers_test.go +++ b/internal/broker/handlers_test.go @@ -43,13 +43,13 @@ func TestHandlers_onConnect(t *testing.T) { ok bool }{ {password: "", ok: true}, - {password: "dial://0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", channel: "a/b/c/CONNECTION_ID/", ok: true}, + {password: "dial://k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm/a/b/c/", channel: "a/b/c/CONNECTION_ID/", ok: true}, {password: "dial://a/b/c/"}, {password: "a/b/c/"}, {password: "agsew350290"}, {password: "1.2342/24/225"}, - {password: "fake://0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/"}, + {password: "fake://k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm/a/b/c/"}, } for _, tc := range tests { diff --git a/internal/security/crypto.go b/internal/security/crypto.go index 7d017d39..5f86be3a 100644 --- a/internal/security/crypto.go +++ b/internal/security/crypto.go @@ -127,7 +127,7 @@ func (c *Cipher) EncryptKey(k Key) (string, error) { } // GenerateKey generates a new key. -func (c *Cipher) GenerateKey(masterKey Key, channel string, permissions uint32, expires time.Time, maxRandSalt int16) (string, error) { +func (c *Cipher) GenerateKey(masterKey Key, channel string, permissions uint8, expires time.Time, maxRandSalt int16) (string, error) { if maxRandSalt <= 0 { maxRandSalt = math.MaxInt16 } diff --git a/internal/security/password.go b/internal/security/dial.go similarity index 78% rename from internal/security/password.go rename to internal/security/dial.go index e02ad59a..b802f117 100644 --- a/internal/security/password.go +++ b/internal/security/dial.go @@ -18,28 +18,28 @@ import ( "regexp" ) -// This is a strict password format -var passwordFormat = regexp.MustCompile(`^(dial)\:\/\/(.+)$`) +// This is a strict format for the dial string +var dialFormat = regexp.MustCompile(`^(dial)\:\/\/(.+)$`) -// ParsePassword parses a pre-authorized channel key -func ParsePassword(password string) (string, *Channel) { - parts := passwordFormat.FindStringSubmatch(password) +// ParseDial parses a pre-authorized channel key +func ParseDial(password string) (*Channel, bool) { + parts := dialFormat.FindStringSubmatch(password) if len(parts) != 3 { - return "", nil // Invalid channel + return nil, false // Invalid channel } // Get the scheme and channel and make sure they're valid scheme := parts[1] channel := ParseChannel([]byte(parts[2])) if len(scheme) == 0 || channel == nil || channel.ChannelType == ChannelInvalid { - return "", nil + return nil, false } // For dial to work, the channel must be static if scheme == "dial" && channel.ChannelType == ChannelStatic { - return scheme, channel + return channel, true } // Safe default - return "", nil + return nil, false } diff --git a/internal/security/password_test.go b/internal/security/dial_test.go similarity index 85% rename from internal/security/password_test.go rename to internal/security/dial_test.go index 3f7663fe..5e7c66d2 100644 --- a/internal/security/password_test.go +++ b/internal/security/dial_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestParsePassword(t *testing.T) { +func TestParseDial(t *testing.T) { tests := []struct { input string expect string @@ -32,9 +32,9 @@ func TestParsePassword(t *testing.T) { } for _, tc := range tests { - scheme, channel := ParsePassword(tc.input) - assert.Equal(t, scheme != "", tc.success, tc.input) - if tc.expect != "" { + channel, ok := ParseDial(tc.input) + assert.Equal(t, ok, tc.success, tc.input) + if ok && tc.expect != "" { assert.Equal(t, tc.expect, string(channel.Channel), tc.input) } } diff --git a/internal/security/key.go b/internal/security/key.go index ac14c8dd..78326e39 100644 --- a/internal/security/key.go +++ b/internal/security/key.go @@ -16,7 +16,6 @@ package security import ( "errors" - "math" "strings" "time" @@ -25,16 +24,17 @@ import ( // Access types for a security key. const ( - AllowNone = uint32(0) // Key has no privileges. - AllowMaster = uint32(1 << 0) // Key should be allowed to generate other keys. - AllowRead = uint32(1 << 1) // Key should be allowed to subscribe to the target channel. - AllowWrite = uint32(1 << 2) // Key should be allowed to publish to the target channel. - AllowStore = uint32(1 << 3) // Key should be allowed to write to the message history of the target channel. - AllowLoad = uint32(1 << 4) // Key should be allowed to write to read the message history of the target channel. - AllowPresence = uint32(1 << 5) // Key should be allowed to query the presence on the target channel. + AllowNone = uint8(0) // Key has no privileges. + AllowMaster = uint8(1 << 0) // Key should be allowed to generate other keys. + AllowRead = uint8(1 << 1) // Key should be allowed to subscribe to the target channel. + AllowWrite = uint8(1 << 2) // Key should be allowed to publish to the target channel. + AllowStore = uint8(1 << 3) // Key should be allowed to write to the message history of the target channel. + AllowLoad = uint8(1 << 4) // Key should be allowed to write to read the message history of the target channel. + AllowPresence = uint8(1 << 5) // Key should be allowed to query the presence on the target channel. + AllowDial = uint8(1 << 6) // Key should be allowed to create a 'dial' sub-channel. + AllowExecute = uint8(1 << 7) // Key should be allowed to execute code. (RESERVED) AllowReadWrite = AllowRead | AllowWrite // Key should be allowed to read and write to the target channel. AllowStoreLoad = AllowStore | AllowLoad // Key should be allowed to read and write the message history. - AllowAny = math.MaxUint32 ) // Key errors @@ -100,13 +100,13 @@ func (k Key) SetSignature(value uint32) { } // Permissions gets the permission flags. -func (k Key) Permissions() uint32 { - return uint32(k[15]) +func (k Key) Permissions() uint8 { + return k[15] } // SetPermissions sets the permission flags. -func (k Key) SetPermissions(value uint32) { - k[15] = byte(value) +func (k Key) SetPermissions(value uint8) { + k[15] = value } // ValidateChannel validates the channel string. @@ -253,7 +253,7 @@ func (k Key) IsMaster() bool { } // HasPermission check whether the key provides some permission. -func (k Key) HasPermission(flag uint32) bool { +func (k Key) HasPermission(flag uint8) bool { p := k.Permissions() - return flag == AllowAny || (p&flag) == flag + return (p & flag) == flag } From 4ffbe9e8b235863fd76ad3e782b62d626d4ca3ad Mon Sep 17 00:00:00 2001 From: Roman Date: Sun, 27 Jan 2019 10:06:15 +0800 Subject: [PATCH 3/6] do not send back the dial key --- internal/broker/handlers.go | 5 +++-- internal/broker/handlers_test.go | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/broker/handlers.go b/internal/broker/handlers.go index 713a8957..89847b5e 100644 --- a/internal/broker/handlers.go +++ b/internal/broker/handlers.go @@ -17,7 +17,6 @@ package broker import ( "encoding/json" "fmt" - "github.com/emitter-io/emitter/internal/security/hash" "strings" "time" @@ -26,6 +25,7 @@ import ( "github.com/emitter-io/emitter/internal/provider/contract" "github.com/emitter-io/emitter/internal/provider/logging" "github.com/emitter-io/emitter/internal/security" + "github.com/emitter-io/emitter/internal/security/hash" "github.com/kelindar/binary" ) @@ -274,9 +274,10 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b // OnMe is a handler that returns information to the connection. func (c *Conn) onMe() (interface{}, bool) { + dial := security.ParseChannel([]byte(c.dial)) return &meResponse{ ID: c.ID(), - Dial: string(c.dial), + Dial: string(dial.Channel), }, true } diff --git a/internal/broker/handlers_test.go b/internal/broker/handlers_test.go index 95034bf4..f48d7d9b 100644 --- a/internal/broker/handlers_test.go +++ b/internal/broker/handlers_test.go @@ -25,7 +25,7 @@ func TestHandlers_onMe(t *testing.T) { conn := netmock.NewConn() nc := s.newConn(conn.Client) - nc.dial = "a/b/c/" + nc.dial = "key/a/b/c/" resp, success := nc.onMe() meResp := resp.(*meResponse) From e64a2e80160ae427a993181ad30516d43cf8358f Mon Sep 17 00:00:00 2001 From: Roman Date: Sun, 27 Jan 2019 18:00:46 +0800 Subject: [PATCH 4/6] added 'emitter/dial' request --- internal/broker/conn.go | 3 +- internal/broker/handlers.go | 52 +++++++++++++++++++++++++++----- internal/broker/handlers_dto.go | 25 ++++++++++++--- internal/broker/handlers_test.go | 29 ++++++++++++++++-- 4 files changed, 93 insertions(+), 16 deletions(-) diff --git a/internal/broker/conn.go b/internal/broker/conn.go index 7e80f578..754da739 100644 --- a/internal/broker/conn.go +++ b/internal/broker/conn.go @@ -41,10 +41,10 @@ type Conn struct { username string // The username provided by the client during MQTT connect. luid security.ID // The locally unique id of the connection. guid string // The globally unique id of the connection. - dial string // The pre-authorized channel with a valid key used for dial. service *Service // The service for this connection. subs *message.Counters // The subscriptions for this connection. measurer stats.Measurer // The measurer to use for monitoring. + dials map[string]string // The map of all pre-authorized dials. } // NewConn creates a new connection. @@ -56,6 +56,7 @@ func (s *Service) newConn(t net.Conn) *Conn { socket: t, subs: message.NewCounters(), measurer: s.measurer, + dials: map[string]string{}, } // Generate a globally unique id as well diff --git a/internal/broker/handlers.go b/internal/broker/handlers.go index 89847b5e..32a408a7 100644 --- a/internal/broker/handlers.go +++ b/internal/broker/handlers.go @@ -17,6 +17,7 @@ package broker import ( "encoding/json" "fmt" + "strconv" "strings" "time" @@ -32,6 +33,7 @@ import ( const ( requestKeygen = 548658350 // hash("keygen") requestPresence = 3869262148 // hash("presence") + requestDial = 1673593207 // hash("dial") requestMe = 2539734036 // hash("me") ) @@ -70,7 +72,7 @@ func (c *Conn) onConnect(packet *mqtt.Connect) bool { // If the password was provided, try to parse the 'dial' string. The format // should be: dial://{key}/{channel}/ if dial, ok := c.dialAndSubscribe(string(packet.Password)); ok { - c.dial = dial // Keep it as string, we want to copy later + c.dials["0"] = dial // Keep it as string, we want to copy later return true } @@ -79,6 +81,8 @@ func (c *Conn) onConnect(packet *mqtt.Connect) bool { // creates a new channel for the dial and subscribes to it if allowed func (c *Conn) dialAndSubscribe(password string) (string, bool) { + + // Parse the dial string channel, valid := security.ParseDial(string(password)) if !valid { return "", false @@ -90,7 +94,7 @@ func (c *Conn) dialAndSubscribe(password string) (string, bool) { return "", false } - // Create a new key for the dial + // Create a new key for the dial. target := fmt.Sprintf("%s%s/", channel.Channel, c.ID()) if err := key.SetTarget(target); err != nil { return "", false @@ -184,8 +188,13 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { // OnPublish is a handler for MQTT Publish events. func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { exclude := "" - if len(mqttTopic) <= 1 { - mqttTopic = []byte(c.dial) + if len(mqttTopic) == 0 && c.dials != nil { + mqttTopic = []byte(c.dials["0"]) + exclude = c.ID() + } + + if len(mqttTopic) == 1 && c.dials != nil { + mqttTopic = []byte(c.dials[string(mqttTopic[0])]) exclude = c.ID() } @@ -265,6 +274,9 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b case requestMe: resp, ok = c.onMe() return + case requestDial: + resp, ok = c.onDial(payload) + return default: return } @@ -272,12 +284,38 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b // ------------------------------------------------------------------------------------ +// OnDial handles the dial & subscribe. +func (c *Conn) onDial(payload []byte) (interface{}, bool) { + var message dialRequest + if err := json.Unmarshal(payload, &message); err != nil { + return ErrBadRequest, false + } + + dial, ok := c.dialAndSubscribe(fmt.Sprintf("dial://%s/%s", message.Key, message.Channel)) + if !ok { + return ErrBadRequest, false + } + + c.dials[strconv.FormatInt(int64(message.Index), 10)] = dial + channel := security.ParseChannel([]byte(dial)) + return &dialResponse{ + Status: 200, + Channel: string(channel.Channel), + }, true +} + +// ------------------------------------------------------------------------------------ + // OnMe is a handler that returns information to the connection. func (c *Conn) onMe() (interface{}, bool) { - dial := security.ParseChannel([]byte(c.dial)) + dials := make(map[string]string) + for k, v := range c.dials { + dials[k] = string(security.ParseChannel([]byte(v)).Channel) + } + return &meResponse{ - ID: c.ID(), - Dial: string(dial.Channel), + ID: c.ID(), + Dials: dials, }, true } diff --git a/internal/broker/handlers_dto.go b/internal/broker/handlers_dto.go index 0c4db5e6..6ecd744a 100644 --- a/internal/broker/handlers_dto.go +++ b/internal/broker/handlers_dto.go @@ -90,21 +90,36 @@ func (m *keyGenRequest) access() uint8 { // ------------------------------------------------------------------------------------ -type meResponse struct { - ID string `json:"id"` // The private ID of the connection, ret - Dial string `json:"dial,omitempty"` // The dial channel name +type keyGenResponse struct { + Status int `json:"status"` + Key string `json:"key"` + Channel string `json:"channel"` } // ------------------------------------------------------------------------------------ -type keyGenResponse struct { - Status int `json:"status"` +type dialRequest struct { + Index int32 `json:"index"` Key string `json:"key"` Channel string `json:"channel"` } // ------------------------------------------------------------------------------------ +type dialResponse struct { + Status int `json:"status"` + Channel string `json:"channel,omitempty"` +} + +// ------------------------------------------------------------------------------------ + +type meResponse struct { + ID string `json:"id"` // The private ID of the connection, ret + Dials map[string]string `json:"dials,omitempty"` // The dial channel name +} + +// ------------------------------------------------------------------------------------ + type presenceRequest struct { Key string `json:"key"` // The channel key for this request. Channel string `json:"channel"` // The target channel for this request. diff --git a/internal/broker/handlers_test.go b/internal/broker/handlers_test.go index f48d7d9b..36cb3e58 100644 --- a/internal/broker/handlers_test.go +++ b/internal/broker/handlers_test.go @@ -16,6 +16,29 @@ import ( "github.com/stretchr/testify/mock" ) +func TestHandlers_onDial(t *testing.T) { + provider := secmock.NewContractProvider() + contract := new(secmock.Contract) + contract.On("Validate", mock.Anything).Return(true) + provider.On("Get", mock.Anything).Return(contract, true) + license, _ := security.ParseLicense(testLicense) + s := &Service{ + contracts: provider, + subscriptions: message.NewTrie(), + License: license, + presence: make(chan *presenceNotify, 100), + } + + s.Cipher, _ = s.License.Cipher() + conn := netmock.NewConn() + nc := s.newConn(conn.Client) + resp, success := nc.onDial([]byte(`{ "index": 1, "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/" }`)) + meResp := resp.(*dialResponse) + + assert.True(t, success) + assert.Contains(t, meResp.Channel, "a/b/c/") +} + func TestHandlers_onMe(t *testing.T) { license, _ := security.ParseLicense(testLicense) s := &Service{ @@ -25,12 +48,12 @@ func TestHandlers_onMe(t *testing.T) { conn := netmock.NewConn() nc := s.newConn(conn.Client) - nc.dial = "key/a/b/c/" + nc.dials["0"] = "key/a/b/c/" resp, success := nc.onMe() meResp := resp.(*meResponse) assert.True(t, success) - assert.Equal(t, "a/b/c/", meResp.Dial) + assert.Equal(t, "a/b/c/", meResp.Dials["0"]) assert.NotNil(t, resp) assert.NotZero(t, len(meResp.ID)) } @@ -76,7 +99,7 @@ func TestHandlers_onConnect(t *testing.T) { assert.Equal(t, tc.ok, ok, tc.password) if tc.channel != "" { - assert.Contains(t, string(nc.dial), tc.channel) + assert.Contains(t, string(nc.dials["0"]), tc.channel) } }) } From b625b03fa16d145a8db18efcdeac88d842587ec7 Mon Sep 17 00:00:00 2001 From: Roman Date: Sun, 27 Jan 2019 21:31:25 +0800 Subject: [PATCH 5/6] removed dial, added links --- internal/broker/conn.go | 4 +- internal/broker/handlers.go | 160 +++++++++++++------------- internal/broker/handlers_dto.go | 44 +++---- internal/broker/handlers_test.go | 118 ++++++++----------- internal/broker/service.go | 4 +- internal/message/sub.go | 8 +- internal/message/sub_test.go | 6 +- internal/security/channel.go | 32 ++++++ internal/security/channel_test.go | 31 +++++ internal/security/dial.go | 45 -------- internal/security/dial_test.go | 41 ------- internal/security/hash/murmur_test.go | 6 +- internal/security/key.go | 2 +- 13 files changed, 231 insertions(+), 270 deletions(-) delete mode 100644 internal/security/dial.go delete mode 100644 internal/security/dial_test.go diff --git a/internal/broker/conn.go b/internal/broker/conn.go index 754da739..30a7fbfb 100644 --- a/internal/broker/conn.go +++ b/internal/broker/conn.go @@ -44,7 +44,7 @@ type Conn struct { service *Service // The service for this connection. subs *message.Counters // The subscriptions for this connection. measurer stats.Measurer // The measurer to use for monitoring. - dials map[string]string // The map of all pre-authorized dials. + links map[string]string // The map of all pre-authorized links. } // NewConn creates a new connection. @@ -56,7 +56,7 @@ func (s *Service) newConn(t net.Conn) *Conn { socket: t, subs: message.NewCounters(), measurer: s.measurer, - dials: map[string]string{}, + links: map[string]string{}, } // Generate a globally unique id as well diff --git a/internal/broker/handlers.go b/internal/broker/handlers.go index 32a408a7..098c14b3 100644 --- a/internal/broker/handlers.go +++ b/internal/broker/handlers.go @@ -17,7 +17,7 @@ package broker import ( "encoding/json" "fmt" - "strconv" + "regexp" "strings" "time" @@ -33,10 +33,14 @@ import ( const ( requestKeygen = 548658350 // hash("keygen") requestPresence = 3869262148 // hash("presence") - requestDial = 1673593207 // hash("dial") + requestLink = 2667034312 // hash("link") requestMe = 2539734036 // hash("me") ) +var ( + shortcut = regexp.MustCompile("^[a-zA-Z0-9]{1,2}$") +) + // ------------------------------------------------------------------------------------ // Authorize attempts to authorize a channel with its key @@ -63,58 +67,7 @@ func (c *Conn) authorize(channel *security.Channel, permission uint8) (contract. // onConnect handles the connection authorization func (c *Conn) onConnect(packet *mqtt.Connect) bool { c.username = string(packet.Username) - - // If there's no password provided, we're done - if len(packet.Password) == 0 { - return true - } - - // If the password was provided, try to parse the 'dial' string. The format - // should be: dial://{key}/{channel}/ - if dial, ok := c.dialAndSubscribe(string(packet.Password)); ok { - c.dials["0"] = dial // Keep it as string, we want to copy later - return true - } - - return false -} - -// creates a new channel for the dial and subscribes to it if allowed -func (c *Conn) dialAndSubscribe(password string) (string, bool) { - - // Parse the dial string - channel, valid := security.ParseDial(string(password)) - if !valid { - return "", false - } - - // Check the authorization and permissions - _, key, allowed := c.authorize(channel, security.AllowDial) - if !allowed { - return "", false - } - - // Create a new key for the dial. - target := fmt.Sprintf("%s%s/", channel.Channel, c.ID()) - if err := key.SetTarget(target); err != nil { - return "", false - } - - // Encrypt the key for storing - encryptedKey, err := c.service.Cipher.EncryptKey(key) - if err != nil { - return "", false - } - - // Auto-subscribe to the dial if allowed - if key.HasPermission(security.AllowRead) { - channel.Channel = []byte(target) - channel.Query = append(channel.Query, hash.Of([]byte(c.ID()))) - ssid := message.NewSsid(key.Contract(), channel) - c.Subscribe(ssid, channel.Channel) - } - - return fmt.Sprintf("%s/%s", encryptedKey, target), true + return true } // ------------------------------------------------------------------------------------ @@ -135,7 +88,7 @@ func (c *Conn) onSubscribe(mqttTopic []byte) *Error { } // Subscribe the client to the channel - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) c.Subscribe(ssid, channel.Channel) // In case of ttl, check the key provides the permission to store (soft permission) @@ -177,7 +130,7 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { } // Unsubscribe the client from the channel - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) c.Unsubscribe(ssid, channel.Channel) c.track(contract) return nil @@ -188,13 +141,8 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error { // OnPublish is a handler for MQTT Publish events. func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { exclude := "" - if len(mqttTopic) == 0 && c.dials != nil { - mqttTopic = []byte(c.dials["0"]) - exclude = c.ID() - } - - if len(mqttTopic) == 1 && c.dials != nil { - mqttTopic = []byte(c.dials[string(mqttTopic[0])]) + if len(mqttTopic) <= 2 && c.links != nil { + mqttTopic = []byte(c.links[string(mqttTopic)]) exclude = c.ID() } @@ -223,7 +171,7 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error { // Create a new message msg := message.New( - message.NewSsid(key.Contract(), channel), + message.NewSsid(key.Contract(), channel.Query), channel.Channel, payload, ) @@ -274,8 +222,8 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b case requestMe: resp, ok = c.onMe() return - case requestDial: - resp, ok = c.onDial(payload) + case requestLink: + resp, ok = c.onLink(payload) return default: return @@ -284,38 +232,88 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b // ------------------------------------------------------------------------------------ -// OnDial handles the dial & subscribe. -func (c *Conn) onDial(payload []byte) (interface{}, bool) { - var message dialRequest - if err := json.Unmarshal(payload, &message); err != nil { +// onLink handles a request to create a link. +func (c *Conn) onLink(payload []byte) (interface{}, bool) { + var request linkRequest + if err := json.Unmarshal(payload, &request); err != nil { return ErrBadRequest, false } - dial, ok := c.dialAndSubscribe(fmt.Sprintf("dial://%s/%s", message.Key, message.Channel)) - if !ok { + // Check whether the name is a valid shortcut name + if !shortcut.Match([]byte(request.Name)) { + return ErrLinkInvalid, false + } + + // Make the channel from the request or try to make a private one + channel := security.MakeChannel(request.Key, request.Channel) + if request.Private { + channel = c.makePrivateChannel(request.Key, request.Channel) + } + + // Ensures that the channel requested is valid + if channel == nil || channel.ChannelType == security.ChannelInvalid { return ErrBadRequest, false } - c.dials[strconv.FormatInt(int64(message.Index), 10)] = dial - channel := security.ParseChannel([]byte(dial)) - return &dialResponse{ + // Create the link with the name and set the full channel to it + c.links[request.Name] = channel.String() + + // If an auto-subscribe was requested and the key has read permissions, subscribe + if _, key, allowed := c.authorize(channel, security.AllowRead); allowed && request.Subscribe { + c.Subscribe(message.NewSsid(key.Contract(), channel.Query), channel.Channel) + } + + return &linkResponse{ Status: 200, - Channel: string(channel.Channel), + Name: request.Name, + Channel: channel.SafeString(), }, true } +// makePrivateChannel creates a private channel and an appropriate key. +func (c *Conn) makePrivateChannel(chanKey, chanName string) *security.Channel { + channel := security.MakeChannel(chanKey, chanName) + if channel.ChannelType != security.ChannelStatic { + return nil + } + + // Make sure we can actually extend it + _, key, allowed := c.authorize(channel, security.AllowExtend) + if !allowed { + return nil + } + + // Create a new key for the private link + target := fmt.Sprintf("%s%s/", channel.Channel, c.ID()) + if err := key.SetTarget(target); err != nil { + return nil + } + + // Encrypt the key for storing + encryptedKey, err := c.service.Cipher.EncryptKey(key) + if err != nil { + return nil + } + + // Create the private channel + channel.Channel = []byte(target) + channel.Query = append(channel.Query, hash.Of([]byte(c.ID()))) + channel.Key = []byte(encryptedKey) + return channel +} + // ------------------------------------------------------------------------------------ // OnMe is a handler that returns information to the connection. func (c *Conn) onMe() (interface{}, bool) { - dials := make(map[string]string) - for k, v := range c.dials { - dials[k] = string(security.ParseChannel([]byte(v)).Channel) + links := make(map[string]string) + for k, v := range c.links { + links[k] = security.ParseChannel([]byte(v)).SafeString() } return &meResponse{ ID: c.ID(), - Dials: dials, + Links: links, }, true } @@ -470,7 +468,7 @@ func (c *Conn) onPresence(payload []byte) (interface{}, bool) { } // Create the ssid for the presence - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) // Check if the client is interested in subscribing/unsubscribing from changes. if msg.Changes { diff --git a/internal/broker/handlers_dto.go b/internal/broker/handlers_dto.go index 6ecd744a..f50412b0 100644 --- a/internal/broker/handlers_dto.go +++ b/internal/broker/handlers_dto.go @@ -35,15 +35,16 @@ func (e *Error) Error() string { return e.Message } // Represents a set of errors used in the handlers. var ( - ErrBadRequest = &Error{Status: 400, Message: "The request was invalid or cannot be otherwise served."} - ErrUnauthorized = &Error{Status: 401, Message: "The security key provided is not authorized to perform this operation."} - ErrPaymentRequired = &Error{Status: 402, Message: "The request can not be served, as the payment is required to proceed."} - ErrForbidden = &Error{Status: 403, Message: "The request is understood, but it has been refused or access is not allowed."} - ErrNotFound = &Error{Status: 404, Message: "The resource requested does not exist."} - ErrServerError = &Error{Status: 500, Message: "An unexpected condition was encountered and no more specific message is suitable."} - ErrNotImplemented = &Error{Status: 501, Message: "The server either does not recognize the request method, or it lacks the ability to fulfill the request."} - ErrTargetInvalid = &Error{Status: 400, Message: "Channel should end with `/` for strict types or `/#/` for wildcards."} - ErrTargetTooLong = &Error{Status: 400, Message: "Channel can not have more than 23 parts."} + ErrBadRequest = &Error{Status: 400, Message: "the request was invalid or cannot be otherwise served"} + ErrUnauthorized = &Error{Status: 401, Message: "the security key provided is not authorized to perform this operation"} + ErrPaymentRequired = &Error{Status: 402, Message: "the request can not be served, as the payment is required to proceed"} + ErrForbidden = &Error{Status: 403, Message: "the request is understood, but it has been refused or access is not allowed"} + ErrNotFound = &Error{Status: 404, Message: "the resource requested does not exist"} + ErrServerError = &Error{Status: 500, Message: "an unexpected condition was encountered and no more specific message is suitable"} + ErrNotImplemented = &Error{Status: 501, Message: "the server either does not recognize the request method, or it lacks the ability to fulfill the request"} + ErrTargetInvalid = &Error{Status: 400, Message: "channel should end with `/` for strict types or `/#/` for wildcards"} + ErrTargetTooLong = &Error{Status: 400, Message: "channel can not have more than 23 parts."} + ErrLinkInvalid = &Error{Status: 400, Message: "the link must be an alphanumeric string of 1 or 2 characters"} ) // ------------------------------------------------------------------------------------ @@ -78,8 +79,8 @@ func (m *keyGenRequest) access() uint8 { required |= security.AllowLoad case 'p': required |= security.AllowPresence - case 'd': - required |= security.AllowDial + case 'e': + required |= security.AllowExtend case 'x': required |= security.AllowExecute } @@ -98,24 +99,27 @@ type keyGenResponse struct { // ------------------------------------------------------------------------------------ -type dialRequest struct { - Index int32 `json:"index"` - Key string `json:"key"` - Channel string `json:"channel"` +type linkRequest struct { + Name string `json:"name"` // The name of the shortcut, max 2 characters. + Key string `json:"key"` // The key for the channel. + Channel string `json:"channel"` // The channel name for the shortcut. + Subscribe bool `json:"subscribe"` // Specifies whether the broker should auto-subscribe. + Private bool `json:"private"` // Specifies whether the broker should generate a private link. } // ------------------------------------------------------------------------------------ -type dialResponse struct { - Status int `json:"status"` - Channel string `json:"channel,omitempty"` +type linkResponse struct { + Status int `json:"status"` // The status of the response. + Name string `json:"name,omitempty"` // The name of the shortcut, max 2 characters. + Channel string `json:"channel,omitempty"` // The channel which was registered. } // ------------------------------------------------------------------------------------ type meResponse struct { - ID string `json:"id"` // The private ID of the connection, ret - Dials map[string]string `json:"dials,omitempty"` // The dial channel name + ID string `json:"id"` // The private ID of the connection. + Links map[string]string `json:"links,omitempty"` // The set of pre-defined channels. } // ------------------------------------------------------------------------------------ diff --git a/internal/broker/handlers_test.go b/internal/broker/handlers_test.go index 36cb3e58..0f36eea4 100644 --- a/internal/broker/handlers_test.go +++ b/internal/broker/handlers_test.go @@ -1,7 +1,6 @@ package broker import ( - "github.com/emitter-io/emitter/internal/network/mqtt" "testing" "github.com/emitter-io/emitter/internal/message" @@ -16,72 +15,36 @@ import ( "github.com/stretchr/testify/mock" ) -func TestHandlers_onDial(t *testing.T) { - provider := secmock.NewContractProvider() - contract := new(secmock.Contract) - contract.On("Validate", mock.Anything).Return(true) - provider.On("Get", mock.Anything).Return(contract, true) - license, _ := security.ParseLicense(testLicense) - s := &Service{ - contracts: provider, - subscriptions: message.NewTrie(), - License: license, - presence: make(chan *presenceNotify, 100), - } - - s.Cipher, _ = s.License.Cipher() - conn := netmock.NewConn() - nc := s.newConn(conn.Client) - resp, success := nc.onDial([]byte(`{ "index": 1, "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/" }`)) - meResp := resp.(*dialResponse) - - assert.True(t, success) - assert.Contains(t, meResp.Channel, "a/b/c/") -} - -func TestHandlers_onMe(t *testing.T) { - license, _ := security.ParseLicense(testLicense) - s := &Service{ - subscriptions: message.NewTrie(), - License: license, - } - - conn := netmock.NewConn() - nc := s.newConn(conn.Client) - nc.dials["0"] = "key/a/b/c/" - resp, success := nc.onMe() - meResp := resp.(*meResponse) - - assert.True(t, success) - assert.Equal(t, "a/b/c/", meResp.Dials["0"]) - assert.NotNil(t, resp) - assert.NotZero(t, len(meResp.ID)) -} - -func TestHandlers_onConnect(t *testing.T) { - license, _ := security.ParseLicense(testLicense) +func TestHandlers_onLink(t *testing.T) { tests := []struct { - password string - channel string - ok bool + packet string + channel string + success bool }{ - {password: "", ok: true}, - {password: "dial://k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm/a/b/c/", channel: "a/b/c/CONNECTION_ID/", ok: true}, - - {password: "dial://a/b/c/"}, - {password: "a/b/c/"}, - {password: "agsew350290"}, - {password: "1.2342/24/225"}, - {password: "fake://k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm/a/b/c/"}, + { + packet: `{ "name": "AB", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`, + channel: "a/b/c/", + success: true, + }, + { + packet: `{ "name": "AB", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/"}`, + channel: "a/b/c/", + success: true, + }, + {packet: `{ "name": "ABC", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + {packet: `{ "name": "", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + {packet: `{"key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true, "subscribe": true }`}, + {packet: `{ "name": "AB", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "---", "private": true, "subscribe": true }`}, + {packet: `{ "name": "AB", "key": "xxx", "channel": "a/b/c/", "private": true, "subscribe": true }`}, } for _, tc := range tests { - t.Run(tc.password, func(*testing.T) { + t.Run(tc.packet, func(*testing.T) { provider := secmock.NewContractProvider() contract := new(secmock.Contract) contract.On("Validate", mock.Anything).Return(true) - contract.On("Stats").Return(usage.NewMeter(0)) provider.On("Get", mock.Anything).Return(contract, true) + license, _ := security.ParseLicense(testLicense) s := &Service{ contracts: provider, subscriptions: message.NewTrie(), @@ -89,22 +52,38 @@ func TestHandlers_onConnect(t *testing.T) { presence: make(chan *presenceNotify, 100), } + s.Cipher, _ = s.License.Cipher() conn := netmock.NewConn() nc := s.newConn(conn.Client) - nc.guid = "CONNECTION_ID" - s.Cipher, _ = s.License.Cipher() - ok := nc.onConnect(&mqtt.Connect{ - Password: []byte(tc.password), - }) - assert.Equal(t, tc.ok, ok, tc.password) - if tc.channel != "" { - assert.Contains(t, string(nc.dials["0"]), tc.channel) + resp, ok := nc.onLink([]byte(tc.packet)) + assert.Equal(t, tc.success, ok) + if tc.success { + assert.Contains(t, resp.(*linkResponse).Channel, tc.channel) } }) } } +func TestHandlers_onMe(t *testing.T) { + license, _ := security.ParseLicense(testLicense) + s := &Service{ + subscriptions: message.NewTrie(), + License: license, + } + + conn := netmock.NewConn() + nc := s.newConn(conn.Client) + nc.links["0"] = "key/a/b/c/" + resp, success := nc.onMe() + meResp := resp.(*meResponse) + + assert.True(t, success) + assert.Equal(t, "a/b/c/", meResp.Links["0"]) + assert.NotNil(t, resp) + assert.NotZero(t, len(meResp.ID)) +} + func TestHandlers_onSubscribeUnsubscribe(t *testing.T) { license, _ := security.ParseLicense(testLicense) tests := []struct { @@ -223,7 +202,7 @@ func TestHandlers_onSubscribeUnsubscribe(t *testing.T) { // Search for the ssid. channel := security.ParseChannel([]byte(tc.channel)) key, _ := s.Cipher.DecryptKey(channel.Key) - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) subscribers := s.subscriptions.Lookup(ssid) assert.Equal(t, tc.subCount, len(subscribers)) @@ -571,6 +550,11 @@ func TestHandlers_onEmitterRequest(t *testing.T) { query: []uint32{requestMe}, success: true, }, + { + channel: "link", + query: []uint32{requestLink}, + success: false, + }, } for _, tc := range tests { diff --git a/internal/broker/service.go b/internal/broker/service.go index ba0ea085..3306f63f 100644 --- a/internal/broker/service.go +++ b/internal/broker/service.go @@ -364,7 +364,7 @@ func (s *Service) onHTTPPresence(w http.ResponseWriter, r *http.Request) { } // Create the ssid for the presence - ssid := message.NewSsid(key.Contract(), channel) + ssid := message.NewSsid(key.Contract(), channel.Query) now := time.Now().UTC().Unix() who := getAllPresence(s, ssid) resp, err := json.Marshal(&presenceResponse{ @@ -457,7 +457,7 @@ func (s *Service) selfPublish(channelName string, payload []byte) { channel := security.ParseChannel([]byte("emitter/" + channelName)) if channel.ChannelType == security.ChannelStatic { s.publish(message.New( - message.NewSsid(s.License.Contract, channel), + message.NewSsid(s.License.Contract, channel.Query), channel.Channel, payload, ), "") diff --git a/internal/message/sub.go b/internal/message/sub.go index ea0c878f..ab9b47c5 100644 --- a/internal/message/sub.go +++ b/internal/message/sub.go @@ -20,8 +20,6 @@ import ( "sync" "time" "unsafe" - - "github.com/emitter-io/emitter/internal/security" ) // Various constant parts of the SSID. @@ -40,10 +38,10 @@ var Query = Ssid{system, query} type Ssid []uint32 // NewSsid creates a new SSID. -func NewSsid(contract uint32, c *security.Channel) Ssid { - ssid := make([]uint32, 0, len(c.Query)+1) +func NewSsid(contract uint32, query []uint32) Ssid { + ssid := make([]uint32, 0, len(query)+1) ssid = append(ssid, uint32(contract)) - ssid = append(ssid, c.Query...) + ssid = append(ssid, query...) return ssid } diff --git a/internal/message/sub_test.go b/internal/message/sub_test.go index beeecc39..77ce8735 100644 --- a/internal/message/sub_test.go +++ b/internal/message/sub_test.go @@ -8,7 +8,7 @@ import ( ) func BenchmarkSsidEncode(b *testing.B) { - ssid := NewSsid(0, &security.Channel{Query: []uint32{56498455, 44565213, 46512350, 18204498}}) + ssid := NewSsid(0, []uint32{56498455, 44565213, 46512350, 18204498}) b.ReportAllocs() b.ResetTimer() @@ -32,7 +32,7 @@ func TestSsid(t *testing.T) { ChannelType: security.ChannelStatic, } - ssid := NewSsid(0, &c) + ssid := NewSsid(0, c.Query) assert.Equal(t, uint32(0), ssid.Contract()) assert.Equal(t, uint32(0x2c), ssid.GetHashCode()) } @@ -53,7 +53,7 @@ func TestSsidEncode(t *testing.T) { } for _, tc := range tests { - ssid := NewSsid(0, &security.Channel{Query: tc.ssid}) + ssid := NewSsid(0, tc.ssid) assert.Equal(t, tc.expected, ssid.Encode()) } } diff --git a/internal/security/channel.go b/internal/security/channel.go index 8f866f7b..4295b7d2 100644 --- a/internal/security/channel.go +++ b/internal/security/channel.go @@ -15,6 +15,7 @@ package security import ( + "fmt" "strconv" "time" "unsafe" @@ -75,6 +76,32 @@ func (c *Channel) Window() (time.Time, time.Time) { return toUnix(u0), toUnix(u1) } +// SafeString returns a string representation of the channel without the key. +func (c *Channel) SafeString() string { + text := string(c.Channel) + if len(c.Options) == 0 { + return text + } + + text += "?" + for i, v := range c.Options { + if i > 0 { + text += "&" + } + + text += v.Key + "=" + v.Value + } + return text +} + +// String returns a string representation of the channel. +func (c *Channel) String() string { + text := string(c.Key) + text += "/" + text += c.SafeString() + return text +} + // Converts the time to Unix Time with validation. func toUnix(t int64) time.Time { if t == 0 || t < MinTime || t > MaxTime { @@ -97,6 +124,11 @@ func (c *Channel) getOption(name string) (int64, bool) { return 0, false } +// MakeChannel attempts to parse the channel from the key and channel strings. +func MakeChannel(key, channel string) *Channel { + return ParseChannel([]byte(fmt.Sprintf("%s/%s", key, channel))) +} + // ParseChannel attempts to parse the channel from the underlying slice. func ParseChannel(text []byte) (channel *Channel) { channel = new(Channel) diff --git a/internal/security/channel_test.go b/internal/security/channel_test.go index f278865e..715ba1e7 100644 --- a/internal/security/channel_test.go +++ b/internal/security/channel_test.go @@ -195,3 +195,34 @@ func TestGetChannelTarget(t *testing.T) { assert.Equal(t, tc.target, target) } } + +func TestMakeChannel(t *testing.T) { + tests := []struct { + key string + channel string + }{ + {key: "key1", channel: "emitter/a/"}, + } + + for _, tc := range tests { + channel := MakeChannel(tc.key, tc.channel) + assert.Equal(t, tc.key, string(channel.Key)) + assert.Equal(t, tc.channel, string(channel.Channel)) + } +} + +func TestChannelString(t *testing.T) { + tests := []struct { + channel string + }{ + {channel: "emitter/a/?last=42&abc=9"}, + {channel: "emitter/a/?last=1200"}, + {channel: "emitter/a/?last=1200a"}, + {channel: "emitter/a/"}, + } + + for _, tc := range tests { + channel := ParseChannel([]byte(tc.channel)) + assert.Equal(t, tc.channel, channel.String()) + } +} diff --git a/internal/security/dial.go b/internal/security/dial.go deleted file mode 100644 index b802f117..00000000 --- a/internal/security/dial.go +++ /dev/null @@ -1,45 +0,0 @@ -/********************************************************************************** -* Copyright (c) 2009-2017 Misakai Ltd. -* This program is free software: you can redistribute it and/or modify it under the -* terms of the GNU Affero General Public License as published by the Free Software -* Foundation, either version 3 of the License, or(at your option) any later version. -* -* This program is distributed in the hope that it will be useful, but WITHOUT ANY -* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A -* PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. -* -* You should have received a copy of the GNU Affero General Public License along -* with this program. If not, see. -************************************************************************************/ - -package security - -import ( - "regexp" -) - -// This is a strict format for the dial string -var dialFormat = regexp.MustCompile(`^(dial)\:\/\/(.+)$`) - -// ParseDial parses a pre-authorized channel key -func ParseDial(password string) (*Channel, bool) { - parts := dialFormat.FindStringSubmatch(password) - if len(parts) != 3 { - return nil, false // Invalid channel - } - - // Get the scheme and channel and make sure they're valid - scheme := parts[1] - channel := ParseChannel([]byte(parts[2])) - if len(scheme) == 0 || channel == nil || channel.ChannelType == ChannelInvalid { - return nil, false - } - - // For dial to work, the channel must be static - if scheme == "dial" && channel.ChannelType == ChannelStatic { - return channel, true - } - - // Safe default - return nil, false -} diff --git a/internal/security/dial_test.go b/internal/security/dial_test.go deleted file mode 100644 index 5e7c66d2..00000000 --- a/internal/security/dial_test.go +++ /dev/null @@ -1,41 +0,0 @@ -package security - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestParseDial(t *testing.T) { - tests := []struct { - input string - expect string - success bool - }{ - {input: "dial://emitter/a/?ttl=42&abc=9", expect: "a/", success: true}, - {input: "dial://emitter/a/?ttl=1200", expect: "a/", success: true}, - {input: "dial://emitter/a/?ttl=1200a", expect: "a/", success: true}, - {input: "dial://emitter/a/b/c/", expect: "a/b/c/", success: true}, - - {input: "auto:/test"}, - {input: "auto://"}, - {input: "auto://--------2345205982)@#(%&*@#//)//%"}, - {input: "dial://emitter/+/"}, - {input: "emitter/a/?ttl=42&abc=9"}, - {input: "emitter/a/?ttl=1200"}, - {input: "emitter/a/?ttl=1200a"}, - {input: "emitter/a/"}, - {input: "err://emitter/a/?ttl=42&abc=9"}, - {input: "err://emitter/a/?ttl=1200"}, - {input: "err://emitter/a/?ttl=1200a"}, - {input: "err://emitter/a/"}, - } - - for _, tc := range tests { - channel, ok := ParseDial(tc.input) - assert.Equal(t, ok, tc.success, tc.input) - if ok && tc.expect != "" { - assert.Equal(t, tc.expect, string(channel.Channel), tc.input) - } - } -} diff --git a/internal/security/hash/murmur_test.go b/internal/security/hash/murmur_test.go index c8d7685a..6beaedb5 100644 --- a/internal/security/hash/murmur_test.go +++ b/internal/security/hash/murmur_test.go @@ -36,9 +36,9 @@ func TestMeHash(t *testing.T) { assert.Equal(t, uint32(2539734036), h) } -func TestDialHash(t *testing.T) { - h := Of([]byte("dial")) - assert.Equal(t, uint32(1673593207), h) +func TestLinkHash(t *testing.T) { + h := Of([]byte("link")) + assert.Equal(t, uint32(2667034312), h) } func TestGetHash(t *testing.T) { diff --git a/internal/security/key.go b/internal/security/key.go index 78326e39..d416e361 100644 --- a/internal/security/key.go +++ b/internal/security/key.go @@ -31,7 +31,7 @@ const ( AllowStore = uint8(1 << 3) // Key should be allowed to write to the message history of the target channel. AllowLoad = uint8(1 << 4) // Key should be allowed to write to read the message history of the target channel. AllowPresence = uint8(1 << 5) // Key should be allowed to query the presence on the target channel. - AllowDial = uint8(1 << 6) // Key should be allowed to create a 'dial' sub-channel. + AllowExtend = uint8(1 << 6) // Key should be allowed to create sub-channels by extending an existing one. AllowExecute = uint8(1 << 7) // Key should be allowed to execute code. (RESERVED) AllowReadWrite = AllowRead | AllowWrite // Key should be allowed to read and write to the target channel. AllowStoreLoad = AllowStore | AllowLoad // Key should be allowed to read and write the message history. From 829fe3c973f35fe094d4f45e626d8c7a84514d2d Mon Sep 17 00:00:00 2001 From: Roman Date: Sun, 27 Jan 2019 21:51:25 +0800 Subject: [PATCH 6/6] added to the integration test --- internal/broker/service_test.go | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/internal/broker/service_test.go b/internal/broker/service_test.go index a19ada41..98a65510 100644 --- a/internal/broker/service_test.go +++ b/internal/broker/service_test.go @@ -233,6 +233,32 @@ func TestPubsub(t *testing.T) { assert.Equal(t, mqtt.TypeOfUnsuback, pkt.Type()) } + { // Create a private link + msg := mqtt.Publish{ + Header: &mqtt.StaticHeader{QOS: 0}, + Topic: []byte("emitter/link/"), + Payload: []byte(`{ "name": "hi", "key": "k44Ss59ZSxg6Zyz39kLwN-2t5AETnGpm", "channel": "a/b/c/", "private": true }`), + } + _, err := msg.EncodeTo(cli) + assert.NoError(t, err) + } + + { // Read the link response + pkt, err := mqtt.DecodePacket(cli) + assert.NoError(t, err) + assert.Equal(t, mqtt.TypeOfPublish, pkt.Type()) + } + + { // Publish a message to a link + msg := mqtt.Publish{ + Header: &mqtt.StaticHeader{QOS: 0}, + Topic: []byte("hi"), + Payload: []byte("hello world"), + } + _, err := msg.EncodeTo(cli) + assert.NoError(t, err) + } + { // Disconnect from the broker disconnect := mqtt.Disconnect{} n, err := disconnect.EncodeTo(cli)