Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Pre-registered links and simplified client-server setup #183

Merged
merged 6 commits into from
Jan 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions internal/broker/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type Conn struct {
service *Service // The service for this connection.
subs *message.Counters // The subscriptions for this connection.
measurer stats.Measurer // The measurer to use for monitoring.
links map[string]string // The map of all pre-authorized links.
}

// NewConn creates a new connection.
Expand All @@ -55,6 +56,7 @@ func (s *Service) newConn(t net.Conn) *Conn {
socket: t,
subs: message.NewCounters(),
measurer: s.measurer,
links: map[string]string{},
}

// Generate a globally unique id as well
Expand Down Expand Up @@ -137,11 +139,13 @@ func (c *Conn) onReceive(msg mqtt.Message) error {

// We got an attempt to connect to MQTT.
case mqtt.TypeOfConnect:
packet := msg.(*mqtt.Connect)
c.username = string(packet.Username)
var result uint8
if !c.onConnect(msg.(*mqtt.Connect)) {
result = 0x05 // Unauthorized
}

// Write the ack
ack := mqtt.Connack{ReturnCode: 0x00}
ack := mqtt.Connack{ReturnCode: result}
if _, err := ack.EncodeTo(c.socket); err != nil {
return err
}
Expand Down
213 changes: 137 additions & 76 deletions internal/broker/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,61 +16,79 @@ package broker

import (
"encoding/json"
"fmt"
"regexp"
"strings"
"time"

"github.com/emitter-io/emitter/internal/message"
"github.com/emitter-io/emitter/internal/network/mqtt"
"github.com/emitter-io/emitter/internal/provider/contract"
"github.com/emitter-io/emitter/internal/provider/logging"
"github.com/emitter-io/emitter/internal/security"
"github.com/emitter-io/emitter/internal/security/hash"
"github.com/kelindar/binary"
)

const (
requestKeygen = 548658350 // hash("keygen")
requestPresence = 3869262148 // hash("presence")
requestLink = 2667034312 // hash("link")
requestMe = 2539734036 // hash("me")
)

// ------------------------------------------------------------------------------------
var (
shortcut = regexp.MustCompile("^[a-zA-Z0-9]{1,2}$")
)

// OnSubscribe is a handler for MQTT Subscribe events.
func (c *Conn) onSubscribe(mqttTopic []byte) *Error {
// ------------------------------------------------------------------------------------

// Parse the channel
channel := security.ParseChannel(mqttTopic)
if channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest
}
// Authorize attempts to authorize a channel with its key
func (c *Conn) authorize(channel *security.Channel, permission uint8) (contract.Contract, security.Key, bool) {

// Attempt to parse the key
key, err := c.service.Cipher.DecryptKey(channel.Key)
if err != nil || key.IsExpired() {
return ErrUnauthorized
return nil, nil, false
}

// Attempt to fetch the contract using the key. Underneath, it's cached.
contract, contractFound := c.service.contracts.Get(key.Contract())
if !contractFound {
return ErrNotFound
if !contractFound || !contract.Validate(key) || !key.HasPermission(permission) || !key.ValidateChannel(channel) {
return nil, nil, false
}

// Validate the contract
if !contract.Validate(key) {
return ErrUnauthorized
}
// Return the contract and the key
return contract, key, true
}

// Check if the key has the permission to read from here
if !key.HasPermission(security.AllowRead) {
return ErrUnauthorized
// ------------------------------------------------------------------------------------

// onConnect handles the connection authorization
func (c *Conn) onConnect(packet *mqtt.Connect) bool {
c.username = string(packet.Username)
return true
}

// ------------------------------------------------------------------------------------

// OnSubscribe is a handler for MQTT Subscribe events.
func (c *Conn) onSubscribe(mqttTopic []byte) *Error {

// Parse the channel
channel := security.ParseChannel(mqttTopic)
if channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest
}

// Check if the key has the permission for the required channel
if !key.ValidateChannel(channel) {
// Check the authorization and permissions
contract, key, allowed := c.authorize(channel, security.AllowRead)
if !allowed {
return ErrUnauthorized
}

// Subscribe the client to the channel
ssid := message.NewSsid(key.Contract(), channel)
ssid := message.NewSsid(key.Contract(), channel.Query)
c.Subscribe(ssid, channel.Channel)

// In case of ttl, check the key provides the permission to store (soft permission)
Expand Down Expand Up @@ -105,35 +123,14 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error {
return ErrBadRequest
}

// Attempt to parse the key
key, err := c.service.Cipher.DecryptKey(channel.Key)
if err != nil || key.IsExpired() {
return ErrUnauthorized
}

// Attempt to fetch the contract using the key. Underneath, it's cached.
contract, contractFound := c.service.contracts.Get(key.Contract())
if !contractFound {
return ErrNotFound
}

// Validate the contract
if !contract.Validate(key) {
return ErrUnauthorized
}

// Check if the key has the permission to read from here
if !key.HasPermission(security.AllowRead) {
return ErrUnauthorized
}

// Check if the key has the permission for the required channel
if !key.ValidateChannel(channel) {
// Check the authorization and permissions
contract, key, allowed := c.authorize(channel, security.AllowRead)
if !allowed {
return ErrUnauthorized
}

// Unsubscribe the client from the channel
ssid := message.NewSsid(key.Contract(), channel)
ssid := message.NewSsid(key.Contract(), channel.Query)
c.Unsubscribe(ssid, channel.Channel)
c.track(contract)
return nil
Expand All @@ -143,8 +140,13 @@ func (c *Conn) onUnsubscribe(mqttTopic []byte) *Error {

// OnPublish is a handler for MQTT Publish events.
func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error {
exclude := ""
if len(mqttTopic) <= 2 && c.links != nil {
mqttTopic = []byte(c.links[string(mqttTopic)])
exclude = c.ID()
}

// Parse the channel
// Make sure we have a valid channel
channel := security.ParseChannel(mqttTopic)
if channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest
Expand All @@ -161,36 +163,15 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error {
return nil
}

// Attempt to parse the key
key, err := c.service.Cipher.DecryptKey(channel.Key)
if err != nil || key.IsExpired() {
return ErrUnauthorized
}

// Attempt to fetch the contract using the key. Underneath, it's cached.
contract, contractFound := c.service.contracts.Get(key.Contract())
if !contractFound {
return ErrNotFound
}

// Validate the contract
if !contract.Validate(key) {
return ErrUnauthorized
}

// Check if the key has the permission to write here
if !key.HasPermission(security.AllowWrite) {
return ErrUnauthorized
}

// Check if the key has the permission for the required channel
if !key.ValidateChannel(channel) {
// Check the authorization and permissions
contract, key, allowed := c.authorize(channel, security.AllowWrite)
if !allowed {
return ErrUnauthorized
}

// Create a new message
msg := message.New(
message.NewSsid(key.Contract(), channel),
message.NewSsid(key.Contract(), channel.Query),
channel.Channel,
payload,
)
Expand All @@ -202,7 +183,7 @@ func (c *Conn) onPublish(mqttTopic []byte, payload []byte) *Error {
}

// Iterate through all subscribers and send them the message
size := c.service.publish(msg)
size := c.service.publish(msg, exclude)

// Write the monitoring information
c.track(contract)
Expand Down Expand Up @@ -241,18 +222,98 @@ func (c *Conn) onEmitterRequest(channel *security.Channel, payload []byte) (ok b
case requestMe:
resp, ok = c.onMe()
return
case requestLink:
resp, ok = c.onLink(payload)
return
default:
return
}
}

// ------------------------------------------------------------------------------------

// onLink handles a request to create a link.
func (c *Conn) onLink(payload []byte) (interface{}, bool) {
var request linkRequest
if err := json.Unmarshal(payload, &request); err != nil {
return ErrBadRequest, false
}

// Check whether the name is a valid shortcut name
if !shortcut.Match([]byte(request.Name)) {
return ErrLinkInvalid, false
}

// Make the channel from the request or try to make a private one
channel := security.MakeChannel(request.Key, request.Channel)
if request.Private {
channel = c.makePrivateChannel(request.Key, request.Channel)
}

// Ensures that the channel requested is valid
if channel == nil || channel.ChannelType == security.ChannelInvalid {
return ErrBadRequest, false
}

// Create the link with the name and set the full channel to it
c.links[request.Name] = channel.String()

// If an auto-subscribe was requested and the key has read permissions, subscribe
if _, key, allowed := c.authorize(channel, security.AllowRead); allowed && request.Subscribe {
c.Subscribe(message.NewSsid(key.Contract(), channel.Query), channel.Channel)
}

return &linkResponse{
Status: 200,
Name: request.Name,
Channel: channel.SafeString(),
}, true
}

// makePrivateChannel creates a private channel and an appropriate key.
func (c *Conn) makePrivateChannel(chanKey, chanName string) *security.Channel {
channel := security.MakeChannel(chanKey, chanName)
if channel.ChannelType != security.ChannelStatic {
return nil
}

// Make sure we can actually extend it
_, key, allowed := c.authorize(channel, security.AllowExtend)
if !allowed {
return nil
}

// Create a new key for the private link
target := fmt.Sprintf("%s%s/", channel.Channel, c.ID())
if err := key.SetTarget(target); err != nil {
return nil
}

// Encrypt the key for storing
encryptedKey, err := c.service.Cipher.EncryptKey(key)
if err != nil {
return nil
}

// Create the private channel
channel.Channel = []byte(target)
channel.Query = append(channel.Query, hash.Of([]byte(c.ID())))
channel.Key = []byte(encryptedKey)
return channel
}

// ------------------------------------------------------------------------------------

// OnMe is a handler that returns information to the connection.
func (c *Conn) onMe() (interface{}, bool) {
// Success, return the response
links := make(map[string]string)
for k, v := range c.links {
links[k] = security.ParseChannel([]byte(v)).SafeString()
}

return &meResponse{
ID: c.ID(),
ID: c.ID(),
Links: links,
}, true
}

Expand Down Expand Up @@ -407,7 +468,7 @@ func (c *Conn) onPresence(payload []byte) (interface{}, bool) {
}

// Create the ssid for the presence
ssid := message.NewSsid(key.Contract(), channel)
ssid := message.NewSsid(key.Contract(), channel.Query)

// Check if the client is interested in subscribing/unsubscribing from changes.
if msg.Changes {
Expand Down
Loading