Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Pre-registered links and simplified client-server setup #183

Merged
merged 6 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions internal/broker/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Conn struct {
username string // The username provided by the client during MQTT connect.
luid security.ID // The locally unique id of the connection.
guid string // The globally unique id of the connection.
dial string // The pre-authorized channel with a valid key used for dial.
service *Service // The service for this connection.
subs *message.Counters // The subscriptions for this connection.
measurer stats.Measurer // The measurer to use for monitoring.
Expand Down Expand Up @@ -137,11 +138,13 @@ func (c *Conn) onReceive(msg mqtt.Message) error {

// We got an attempt to connect to MQTT.
case mqtt.TypeOfConnect:
packet := msg.(*mqtt.Connect)
c.username = string(packet.Username)
var result uint8
if !c.onConnect(msg.(*mqtt.Connect)) {
result = 0x05 // Unauthorized
}

// Write the ack
ack := mqtt.Connack{ReturnCode: 0x00}
ack := mqtt.Connack{ReturnCode: result}
if _, err := ack.EncodeTo(c.socket); err != nil {
return err
}
Expand Down
170 changes: 99 additions & 71 deletions internal/broker/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ package broker

import (
"encoding/json"
"fmt"
"github.com/emitter-io/emitter/internal/security/hash"
"strings"
"time"

"github.com/emitter-io/emitter/internal/message"
"github.com/emitter-io/emitter/internal/network/mqtt"
"github.com/emitter-io/emitter/internal/provider/contract"
"github.com/emitter-io/emitter/internal/provider/logging"
"github.com/emitter-io/emitter/internal/security"
"github.com/kelindar/binary"
Expand All @@ -33,39 +37,100 @@ const (

// ------------------------------------------------------------------------------------

// OnSubscribe is a handler for MQTT Subscribe events.
func (c *Conn) onSubscribe(mqttTopic []byte) *Error {

// Parse the channel
channel := security.ParseChannel(mqttTopic)
if channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest
}
// Authorize attempts to authorize a channel with its key
func (c *Conn) authorize(channel *security.Channel, permission uint32) (contract.Contract, security.Key, bool) {

// Attempt to parse the key
key, err := c.service.Cipher.DecryptKey(channel.Key)
if err != nil || key.IsExpired() {
return ErrUnauthorized
return nil, nil, false
}

// Attempt to fetch the contract using the key. Underneath, it's cached.
contract, contractFound := c.service.contracts.Get(key.Contract())
if !contractFound {
return ErrNotFound
if !contractFound || !contract.Validate(key) || !key.HasPermission(permission) || !key.ValidateChannel(channel) {
return nil, nil, false
}

// Validate the contract
if !contract.Validate(key) {
return ErrUnauthorized
// Return the contract and the key
return contract, key, true
}

// ------------------------------------------------------------------------------------

// onConnect handles the connection authorization
func (c *Conn) onConnect(packet *mqtt.Connect) bool {
c.username = string(packet.Username)

// If there's no password provided, we're done
if len(packet.Password) == 0 {
return true
}

// Check if the key has the permission to read from here
if !key.HasPermission(security.AllowRead) {
return ErrUnauthorized
// If the password was provided, evalute our scheme. The format should be:
// dial://{key}/{channel}/
scheme, channel := security.ParsePassword(string(packet.Password))
if len(scheme) == 0 || channel == nil || channel.ChannelType == security.ChannelInvalid {
return false
}

// If it's a dial, we need to append the connection ID and generate a new key
if scheme == "dial" {
if dial, ok := c.dialAndSubscribe(channel); ok {
c.dial = dial // Keep it as string, we want to copy later
return true
}
}

return false
}

// creates a new channel for the dial and subscribes to it if allowed
func (c *Conn) dialAndSubscribe(channel *security.Channel) (string, bool) {

// Check the authorization and permissions
_, key, allowed := c.authorize(channel, security.AllowAny)
if !allowed {
return "", false
}

// Create a new key for the dial
target := fmt.Sprintf("%s%s/", channel.Channel, c.ID())
if err := key.SetTarget(target); err != nil {
return "", false
}

// Encrypt the key for storing
encryptedKey, err := c.service.Cipher.EncryptKey(key)
if err != nil {
return "", false
}

// Auto-subscribe to the dial if allowed
if key.HasPermission(security.AllowRead) {
channel.Channel = []byte(target)
channel.Query = append(channel.Query, hash.Of([]byte(c.ID())))
ssid := message.NewSsid(key.Contract(), channel)
c.Subscribe(ssid, channel.Channel)
}

return fmt.Sprintf("%s/%s", encryptedKey, target), true
}

// ------------------------------------------------------------------------------------

// OnSubscribe is a handler for MQTT Subscribe events.
func (c *Conn) onSubscribe(mqttTopic []byte) *Error {

// Parse the channel
channel := security.ParseChannel(mqttTopic)
if channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest
}

// Check if the key has the permission for the required channel
if !key.ValidateChannel(channel) {
// Check the authorization and permissions
contract, key, allowed := c.authorize(channel, security.AllowRead)
if !allowed {
return ErrUnauthorized
}

Expand Down Expand Up @@ -105,30 +170,9 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error {
return ErrBadRequest
}

// Attempt to parse the key
key, err := c.service.Cipher.DecryptKey(channel.Key)
if err != nil || key.IsExpired() {
return ErrUnauthorized
}

// Attempt to fetch the contract using the key. Underneath, it's cached.
contract, contractFound := c.service.contracts.Get(key.Contract())
if !contractFound {
return ErrNotFound
}

// Validate the contract
if !contract.Validate(key) {
return ErrUnauthorized
}

// Check if the key has the permission to read from here
if !key.HasPermission(security.AllowRead) {
return ErrUnauthorized
}

// Check if the key has the permission for the required channel
if !key.ValidateChannel(channel) {
// Check the authorization and permissions
contract, key, allowed := c.authorize(channel, security.AllowRead)
if !allowed {
return ErrUnauthorized
}

Expand All @@ -143,8 +187,13 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error {

// OnPublish is a handler for MQTT Publish events.
func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error {
exclude := ""
if len(mqttTopic) <= 1 {
mqttTopic = []byte(c.dial)
exclude = c.ID()
}

// Parse the channel
// Make sure we have a valid channel
channel := security.ParseChannel(mqttTopic)
if channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest
Expand All @@ -161,30 +210,9 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error {
return nil
}

// Attempt to parse the key
key, err := c.service.Cipher.DecryptKey(channel.Key)
if err != nil || key.IsExpired() {
return ErrUnauthorized
}

// Attempt to fetch the contract using the key. Underneath, it's cached.
contract, contractFound := c.service.contracts.Get(key.Contract())
if !contractFound {
return ErrNotFound
}

// Validate the contract
if !contract.Validate(key) {
return ErrUnauthorized
}

// Check if the key has the permission to write here
if !key.HasPermission(security.AllowWrite) {
return ErrUnauthorized
}

// Check if the key has the permission for the required channel
if !key.ValidateChannel(channel) {
// Check the authorization and permissions
contract, key, allowed := c.authorize(channel, security.AllowWrite)
if !allowed {
return ErrUnauthorized
}

Expand All @@ -202,7 +230,7 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error {
}

// Iterate through all subscribers and send them the message
size := c.service.publish(msg)
size := c.service.publish(msg, exclude)

// Write the monitoring information
c.track(contract)
Expand Down Expand Up @@ -250,9 +278,9 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b

// OnMe is a handler that returns information to the connection.
func (c *Conn) onMe() (interface{}, bool) {
// Success, return the response
return &meResponse{
ID: c.ID(),
ID: c.ID(),
Dial: string(c.dial),
kelindar marked this conversation as resolved.
Show resolved Hide resolved
}, true
}

Expand Down
3 changes: 2 additions & 1 deletion internal/broker/handlers_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func (m *keyGenRequest) access() uint32 {
// ------------------------------------------------------------------------------------

type meResponse struct {
ID string `json:"id"`
ID string `json:"id"` // The private ID of the connection, ret
Dial string `json:"dial,omitempty"` // The dial channel name
}

// ------------------------------------------------------------------------------------
Expand Down
58 changes: 54 additions & 4 deletions internal/broker/handlers_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package broker

import (
"github.com/emitter-io/emitter/internal/network/mqtt"
"testing"

"github.com/emitter-io/emitter/internal/message"
Expand All @@ -24,14 +25,63 @@ func TestHandlers_onMe(t *testing.T) {

conn := netmock.NewConn()
nc := s.newConn(conn.Client)
nc.dial = "a/b/c/"
resp, success := nc.onMe()
meResp := resp.(*meResponse)

assert.Equal(t, success, true, success)
assert.True(t, success)
assert.Equal(t, "a/b/c/", meResp.Dial)
assert.NotNil(t, resp)
assert.NotZero(t, len(meResp.ID))
}

func TestHandlers_onConnect(t *testing.T) {
license, _ := security.ParseLicense(testLicense)
tests := []struct {
password string
channel string
ok bool
}{
{password: "", ok: true},
{password: "dial://0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/", channel: "a/b/c/CONNECTION_ID/", ok: true},

{password: "dial://a/b/c/"},
{password: "a/b/c/"},
{password: "agsew350290"},
{password: "1.2342/24/225"},
{password: "fake://0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/"},
}

for _, tc := range tests {
t.Run(tc.password, func(*testing.T) {
provider := secmock.NewContractProvider()
contract := new(secmock.Contract)
contract.On("Validate", mock.Anything).Return(true)
contract.On("Stats").Return(usage.NewMeter(0))
provider.On("Get", mock.Anything).Return(contract, true)
s := &Service{
contracts: provider,
subscriptions: message.NewTrie(),
License: license,
presence: make(chan *presenceNotify, 100),
}

conn := netmock.NewConn()
nc := s.newConn(conn.Client)
nc.guid = "CONNECTION_ID"
s.Cipher, _ = s.License.Cipher()
ok := nc.onConnect(&mqtt.Connect{
Password: []byte(tc.password),
})

assert.Equal(t, tc.ok, ok, tc.password)
if tc.channel != "" {
assert.Contains(t, string(nc.dial), tc.channel)
}
})
}
}

func TestHandlers_onSubscribeUnsubscribe(t *testing.T) {
license, _ := security.ParseLicense(testLicense)
tests := []struct {
Expand Down Expand Up @@ -86,9 +136,9 @@ func TestHandlers_onSubscribeUnsubscribe(t *testing.T) {
{
channel: "0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/",
subCount: 0,
subErr: ErrNotFound,
subErr: ErrUnauthorized,
unsubCount: 0,
unsubErr: ErrNotFound,
unsubErr: ErrUnauthorized,
contractValid: true,
contractFound: false,
msg: "Contract not found case",
Expand Down Expand Up @@ -210,7 +260,7 @@ func TestHandlers_onPublish(t *testing.T) {
{
channel: "0Nq8SWbL8qoOKEDqh_ebBepug6cLLlWO/a/b/c/",
payload: "test",
err: ErrNotFound,
err: ErrUnauthorized,
contractValid: true,
contractFound: false,
msg: "Contract not found case",
Expand Down
2 changes: 1 addition & 1 deletion internal/broker/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ func (c *QueryManager) Query(query string, payload []byte) (message.Awaiter, err
message.Ssid{idSystem, idQuery, awaiter.id},
[]byte(channel),
payload,
))
), "")
return awaiter, nil
}

Expand Down
Loading