diff --git a/docs/getting-started.md b/docs/getting-started.md index 65fa05c3..7f0f8d02 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -71,7 +71,7 @@ To dispatch an update, the publisher (an application server, a web browser...) n ```http POST example.com HTTP/1.1 -Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8 +Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw topic=https://example.com/books/1&data={"foo": "updated value"} ``` @@ -95,8 +95,8 @@ const req = https.request({ path: '/.well-known/mercure', method: 'POST', headers: { - Authorization: 'Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8', - // the JWT must have a mercure.publish key containing an array of targets (can be empty for public updates) + Authorization: 'Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw', + // the JWT must have a mercure.publish key containing an array of topic selectors (can contain "*" for all topics, and be empty for public updates) // the JWT key must be shared between the hub and the server 'Content-Type': 'application/x-www-form-urlencoded', 'Content-Length': Buffer.byteLength(postData), @@ -109,7 +109,7 @@ req.end(); // but any HTTP client, written in any language, will be just fine. ``` -The JWT must contain a `publish` property containing an array of targets. This array can be empty to allow publishing anonymous updates only. To create and read JWTs try [jwt.io](https://jwt.io) ([demo token](https://jwt.io/#debugger-io?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8), key: `!ChangeMe!`). +The JWT must contain a `publish` property containing an array of topic selectors. This array can be empty to allow publishing anonymous updates only. The topic selector `*` can be used to allow publishing private updates for all topics. To create and read JWTs try [jwt.io](https://jwt.io) ([demo token](https://jwt.io/#debugger-io?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw), key: `!ChangeMe!`). ## Going Further diff --git a/docs/hub/config.md b/docs/hub/config.md index 65e3b7f7..38c822e3 100644 --- a/docs/hub/config.md +++ b/docs/hub/config.md @@ -28,7 +28,7 @@ When using environment variables, list must be space separated. As flags paramet | `cors_allowed_origins` | a list of allowed CORS origins, can be `*` for all | | `debug` | set to `true` to enable the debug mode, **dangerous, don't enable in production** (logs updates' content, why an update is not send to a specific subscriber and recovery stack traces) | | `demo` | set to `true` to enable the demo mode (automatically enabled when `debug=true`) | -| `dispatch_subscriptions` | set to `true` to dispatch updates when a subscription between the Hub and a subscriber is established or closed. The topic follows the template `https://mercure.rocks/subscriptions/{subscriptionID}`. To receive connection updates, subscribers must have `https://mercure.rocks/targets/subscriptions` or an URL matching the template `https://mercure.rocks/targets/subscriptions/{topic}` (`{topic}` is URL-encoded topic of the subscription) as targets | +| `dispatch_subscriptions` | set to `true` to dispatch private updates when a subscription between the Hub and a subscriber is established or closed. The topic follows the template `/.well-known/mercure/subscriptions/{subscriptionID}/{topic}` | | `heartbeat_interval` | interval between heartbeats (useful with some proxies, and old browsers), defaults to `15s`, set to `0s` to disable | | `jwt_key` | the JWT key to use for both publishers and subscribers | | `jwt_algorithm` | the JWT verification algorithm to use for both publishers and subscribers, e.g. HS256 (default) or RS512 | @@ -42,7 +42,6 @@ When using environment variables, list must be space separated. As flags paramet | `read_timeout` | maximum duration for reading the entire request, including the body, set to `0s` to disable (default), example: `2m` | | `subscriber_jwt_key` | must contain the secret key to valid subscribers' JWT, can be omitted if `jwt_key` is set | | `subscriber_jwt_algorithm` | the JWT verification algorithm to use for subscribers, e.g. HS256 (default) or RS512 | -| `subscriptions_include_ip` | set to `true` to include the subscriber's IP in the subscription update | | `transport_url` | URL representation of the history database. Provided database are `null` to disabled history, `bolt` to use [bbolt](https://github.com/etcd-io/bbolt) (example `bolt:///var/run/mercure.db?size=100&cleanup_frequency=0.4`), defaults to `bolt://updates.db` | | `update_buffer_size` | maximum number of updates to allow buffering before closing the connection | | `update_buffer_full_timeout` | time to wait before closing the connection after the buffer is full | diff --git a/docs/hub/troubleshooting.md b/docs/hub/troubleshooting.md index 53aeb971..81091b81 100644 --- a/docs/hub/troubleshooting.md +++ b/docs/hub/troubleshooting.md @@ -5,10 +5,10 @@ * Check the logs written by the hub on `stderr`, they contain the exact reason why the token has been rejected * Be sure to set a **secret key** (and not a JWT) in `JWT_KEY` (or in `SUBSCRIBER_JWT_KEY` and `PUBLISHER_JWT_KEY`) * If the secret key contains special characters, be sure to escape them properly, especially if you set the environment variable in a shell, or in a YAML file (Kubernetes...) -* The publisher always needs a valid JWT, even if `ALLOW_ANONYMOUS` is set to `1`, this JWT **must** have a property named `publish` and containing an array of targets ([example](https://jwt.io/#debugger-io?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOltdfX0.473isprbLWLjXmAaVZj6FIVkCdjn37SQpGjzWws-xa0)) -* The subscriber needs a valid JWT only if `ALLOW_ANONYMOUS` is set to `0` (default), or to subscribe to private updates, in this case the JWT **must** have a property named `subscribe` and containing an array of targets ([example](https://jwt.io/#debugger-io?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6W119fQ.s-6MlTvJ6vpsZ7ftmz3dvWpZznRxnxI0KlrZOHVo8Qc)) +* The publisher always needs a valid JWT, even if `ALLOW_ANONYMOUS` is set to `1`, this JWT **must** have a property named `publish`. To dispatch private updates, the `publish` property must contain the list of topic selectors this publisher can use ([example](https://jwt.io/#debugger-io?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw)) +* The subscriber needs a valid JWT only if `ALLOW_ANONYMOUS` is set to `0` (default), or to subscribe to private updates, in this case the JWT **must** have a property named `subscribe` and containing an array of topic selectors ([example](https://jwt.io/#debugger-io?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw)) -For both the `publish` and `subscribe` properties, the array can be empty to publish only public updates, or set it to `["*"]` to allow accessing to all targets. +For both the `publish` and `subscribe` properties, the array can be empty to publish only public updates, or set it to `["*"]` to allow publishing updates for all topics. ## Browser Issues diff --git a/examples/chat-python-flask/chat.py b/examples/chat-python-flask/chat.py index 252d24fc..2bd0effe 100644 --- a/examples/chat-python-flask/chat.py +++ b/examples/chat-python-flask/chat.py @@ -20,7 +20,6 @@ JWT_KEY: the JWT key to use (must be shared with the Mercure hub) HUB_URL: the URL of the Mercure hub (default: http://localhost:3000/.well-known/mercure) TOPIC: the topic to use (default: http://example.com/chat) - TARGET: the target to use (default: chan) COOKIE_DOMAIN: the cookie domain (default: None) """ @@ -32,16 +31,14 @@ @app.route("/") def chat(): - targets = [os.environ.get('TARGET', 'chan')] + topic = os.environ.get('TOPIC', 'http://example.com/chat') token = jwt.encode( - {'mercure': {'subscribe': targets, 'publish': targets}}, + {'mercure': {'subscribe': [topics], 'publish': [topics]}}, os.environ.get('JWT_KEY', '!ChangeMe!'), algorithm='HS256' ) hub_url = os.environ.get('HUB_URL', 'http://localhost:3000/.well-known/mercure') - topic = os.environ.get('TOPIC', 'http://example.com/chat') - resp = make_response(render_template('chat.html', config={ 'hubURL': hub_url, 'topic': topic})) resp.set_cookie('mercureAuthorization', token, httponly=True, path='/.well-known/mercure', diff --git a/examples/publisher-node.js b/examples/publisher-node.js index 512c8e55..c42124a8 100644 --- a/examples/publisher-node.js +++ b/examples/publisher-node.js @@ -2,7 +2,7 @@ const http = require("http"); const querystring = require("querystring"); const demoJwt = - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8"; + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw"; const postData = querystring.stringify({ topic: "http://localhost:3000/demo/books/1.jsonld", diff --git a/examples/publisher-php.php b/examples/publisher-php.php index 25324826..087462c3 100644 --- a/examples/publisher-php.php +++ b/examples/publisher-php.php @@ -1,6 +1,6 @@ 'http://localhost:3000/demo/books/1.jsonld', diff --git a/examples/publisher-ruby.rb b/examples/publisher-ruby.rb index 7a43b3b3..99aa2c95 100644 --- a/examples/publisher-ruby.rb +++ b/examples/publisher-ruby.rb @@ -1,7 +1,7 @@ require 'json' require 'net/http' -token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8' +token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw' Net::HTTP.start('localhost', 3000) do |http| req = Net::HTTP::Post.new('/.well-known/mercure') diff --git a/gatling/LoadTest.scala b/gatling/LoadTest.scala index eeaffb6a..c64b76ef 100644 --- a/gatling/LoadTest.scala +++ b/gatling/LoadTest.scala @@ -27,7 +27,7 @@ class LoadTest extends Simulation { /** The hub URL */ val HubUrl = Properties.envOrElse("HUB_URL", "http://localhost:3001/.well-known/mercure") /** JWT to use to publish */ - val Jwt = Properties.envOrElse("JWT", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8") + val Jwt = Properties.envOrElse("JWT", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw") /** Number of concurrent subscribers initially connected */ val InitialSubscribers = Properties.envOrElse("INITIAL_SUBSCRIBERS", "100").toInt /** Additional subscribers rate (per second) */ diff --git a/hub/authorization.go b/hub/authorization.go index 7ea5fae3..a99822dc 100644 --- a/hub/authorization.go +++ b/hub/authorization.go @@ -20,32 +20,39 @@ type claims struct { } type mercureClaim struct { - Publish []string `json:"publish"` - Subscribe []string `json:"subscribe"` + Publish []string `json:"publish"` + Subscribe []string `json:"subscribe"` + Payload interface{} `json:"payload"` } type role int const ( - subscriberRole role = iota - publisherRole + roleSubscriber role = iota + rolePublisher ) var ( + // ErrInvalidAuthorizationHeader is returned when the Authorization header is invalid. ErrInvalidAuthorizationHeader = errors.New(`invalid "Authorization" HTTP header`) - ErrNoOrigin = errors.New(`an "Origin" or a "Referer" HTTP header must be present to use the cookie-based authorization mechanism`) - ErrOriginNotAllowed = errors.New("origin not allowed to post updates") - ErrUnexpectedSigningMethod = errors.New("unexpected signing method") - ErrInvalidJWT = errors.New("invalid JWT") - ErrPublicKey = errors.New("public key error") + // ErrNoOrigin is returned when the cookie authorization mechanism is used and no Origin nor Referer headers are presents. + ErrNoOrigin = errors.New(`an "Origin" or a "Referer" HTTP header must be present to use the cookie-based authorization mechanism`) + // ErrOriginNotAllowed is returned when the Origin is not allowed to post updates. + ErrOriginNotAllowed = errors.New("origin not allowed to post updates") + // ErrUnexpectedSigningMethod is returned when the signing JWT method is not supported. + ErrUnexpectedSigningMethod = errors.New("unexpected signing method") + // ErrInvalidJWT is returned when the JWT is invalid. + ErrInvalidJWT = errors.New("invalid JWT") + // ErrPublicKey is returned when there is an error with the public key. + ErrPublicKey = errors.New("public key error") ) func (h *Hub) getJWTKey(r role) []byte { var configKey string switch r { - case subscriberRole: + case roleSubscriber: configKey = "subscriber_jwt_key" - case publisherRole: + case rolePublisher: configKey = "publisher_jwt_key" } @@ -63,9 +70,9 @@ func (h *Hub) getJWTKey(r role) []byte { func (h *Hub) getJWTAlgorithm(r role) jwt.SigningMethod { var configKey string switch r { - case subscriberRole: + case roleSubscriber: configKey = "subscriber_jwt_algorithm" - case publisherRole: + case rolePublisher: configKey = "publisher_jwt_algorithm" } @@ -168,26 +175,36 @@ func validateJWT(encodedToken string, key []byte, signingAlgorithm jwt.SigningMe return nil, ErrInvalidJWT } -func authorizedTargets(claims *claims, publisher bool) (all bool, targets map[string]struct{}) { - if claims == nil { - return false, map[string]struct{}{} +func canReceive(s *topicSelectorStore, topics, topicSelectors []string) bool { + for _, topic := range topics { + for _, topicSelector := range topicSelectors { + if s.match(topic, topicSelector, true) { + return true + } + } } - var providedTargets []string - if publisher { - providedTargets = claims.Mercure.Publish - } else { - providedTargets = claims.Mercure.Subscribe - } + return false +} - authorizedTargets := make(map[string]struct{}, len(providedTargets)) - for _, target := range providedTargets { - if target == "*" { - return true, nil +func canDispatch(s *topicSelectorStore, topics, topicSelectors []string) bool { + for _, topic := range topics { + var matched bool + for _, topicSelector := range topicSelectors { + if topicSelector == "*" { + return true + } + + if s.match(topic, topicSelector, false) { + matched = true + break + } } - authorizedTargets[target] = struct{}{} + if !matched { + return false + } } - return false, authorizedTargets + return true } diff --git a/hub/authorization_test.go b/hub/authorization_test.go index d64e4479..4f70081b 100644 --- a/hub/authorization_test.go +++ b/hub/authorization_test.go @@ -344,50 +344,24 @@ func TestAuthorizeCookieOriginHasPriorityRsa(t *testing.T) { assert.Nil(t, err) } -func TestAuthorizedNilClaim(t *testing.T) { - all, targets := authorizedTargets(nil, true) - assert.False(t, all) - assert.Empty(t, targets) -} - -func TestAuthorizedTargetsPublisher(t *testing.T) { - c := &claims{Mercure: mercureClaim{ - Publish: []string{"foo", "bar"}, - }} - - all, targets := authorizedTargets(c, true) - assert.False(t, all) - assert.Equal(t, map[string]struct{}{"foo": {}, "bar": {}}, targets) -} - -func TestAuthorizedAllTargetsPublisher(t *testing.T) { - c := &claims{Mercure: mercureClaim{ - Publish: []string{"*"}, - }} - - all, targets := authorizedTargets(c, true) - assert.True(t, all) - assert.Empty(t, targets) -} - -func TestAuthorizedTargetsSubscriber(t *testing.T) { - c := &claims{Mercure: mercureClaim{ - Subscribe: []string{"foo", "bar"}, - }} - - all, targets := authorizedTargets(c, false) - assert.False(t, all) - assert.Equal(t, map[string]struct{}{"foo": {}, "bar": {}}, targets) -} - -func TestAuthorizedAllTargetsSubscriber(t *testing.T) { - c := &claims{Mercure: mercureClaim{ - Subscribe: []string{"*"}, - }} - - all, targets := authorizedTargets(c, false) - assert.True(t, all) - assert.Empty(t, targets) +func TestCanReceive(t *testing.T) { + s := newTopicSelectorStore() + assert.True(t, canReceive(s, []string{"foo", "bar"}, []string{"foo", "bar"})) + assert.True(t, canReceive(s, []string{"foo", "bar"}, []string{"bar"})) + assert.True(t, canReceive(s, []string{"foo", "bar"}, []string{"*"})) + assert.False(t, canReceive(s, []string{"foo", "bar"}, []string{})) + assert.False(t, canReceive(s, []string{"foo", "bar"}, []string{"baz"})) + assert.False(t, canReceive(s, []string{"foo", "bar"}, []string{"baz", "bat"})) +} + +func TestCanDispatch(t *testing.T) { + s := newTopicSelectorStore() + assert.True(t, canDispatch(s, []string{"foo", "bar"}, []string{"foo", "bar"})) + assert.True(t, canDispatch(s, []string{"foo", "bar"}, []string{"*"})) + assert.False(t, canDispatch(s, []string{"foo", "bar"}, []string{})) + assert.False(t, canDispatch(s, []string{"foo", "bar"}, []string{"foo"})) + assert.False(t, canDispatch(s, []string{"foo", "bar"}, []string{"baz"})) + assert.False(t, canDispatch(s, []string{"foo", "bar"}, []string{"baz", "bat"})) } func TestGetJWTKeyInvalid(t *testing.T) { @@ -396,12 +370,12 @@ func TestGetJWTKeyInvalid(t *testing.T) { h.config.Set("publisher_jwt_key", "") assert.PanicsWithValue(t, "one of these configuration parameters must be defined: [publisher_jwt_key jwt_key]", func() { - h.getJWTKey(publisherRole) + h.getJWTKey(rolePublisher) }) h.config.Set("subscriber_jwt_key", "") assert.PanicsWithValue(t, "one of these configuration parameters must be defined: [subscriber_jwt_key jwt_key]", func() { - h.getJWTKey(subscriberRole) + h.getJWTKey(roleSubscriber) }) } @@ -411,11 +385,11 @@ func TestGetJWTAlgorithmInvalid(t *testing.T) { h.config.Set("publisher_jwt_algorithm", "foo") assert.PanicsWithValue(t, "invalid signing method: foo", func() { - h.getJWTAlgorithm(publisherRole) + h.getJWTAlgorithm(rolePublisher) }) h.config.Set("subscriber_jwt_algorithm", "foo") assert.PanicsWithValue(t, "invalid signing method: foo", func() { - h.getJWTAlgorithm(subscriberRole) + h.getJWTAlgorithm(roleSubscriber) }) } diff --git a/hub/bolt_transport.go b/hub/bolt_transport.go index 6b0e96d3..37b61821 100644 --- a/hub/bolt_transport.go +++ b/hub/bolt_transport.go @@ -170,11 +170,13 @@ func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) { } c := b.Cursor() - afterFromID := false + afterFromID := s.LastEventID == "-1" + previousID := "-1" for k, v := c.First(); k != nil; k, v = c.Next() { if !afterFromID { if string(k[8:]) == s.LastEventID { afterFromID = true + previousID = "" } continue @@ -185,6 +187,7 @@ func (t *BoltTransport) dispatchHistory(s *Subscriber, toSeq uint64) { log.Error(fmt.Errorf("bolt history: %w", err)) return err } + update.PreviousID = previousID if !s.Dispatch(update, true) || (toSeq > 0 && binary.BigEndian.Uint64(k[:8]) >= toSeq) { return nil diff --git a/hub/bolt_transport_test.go b/hub/bolt_transport_test.go index 9e27b253..fbc9be78 100644 --- a/hub/bolt_transport_test.go +++ b/hub/bolt_transport_test.go @@ -26,9 +26,8 @@ func TestBoltTransportHistory(t *testing.T) { }) } - s := newSubscriber("8") + s := newSubscriber("8", newTopicSelectorStore()) s.Topics = topics - s.RawTopics = topics go s.start() err := transport.AddSubscriber(s) @@ -46,6 +45,39 @@ func TestBoltTransportHistory(t *testing.T) { } } +func TestBoltTransportRetrieveAllHistory(t *testing.T) { + u, _ := url.Parse("bolt://test.db") + transport, _ := NewBoltTransport(u) + defer transport.Close() + defer os.Remove("test.db") + + topics := []string{"https://example.com/foo"} + for i := 1; i <= 10; i++ { + transport.Dispatch(&Update{ + Event: Event{ID: strconv.Itoa(i)}, + Topics: topics, + }) + } + + s := newSubscriber("-1", newTopicSelectorStore()) + s.Topics = topics + go s.start() + + err := transport.AddSubscriber(s) + assert.Nil(t, err) + + var count int + for { + u := <-s.Receive() + // the reading loop must read all messages + count++ + assert.Equal(t, strconv.Itoa(count), u.ID) + if count == 10 { + return + } + } +} + func TestBoltTransportHistoryAndLive(t *testing.T) { u, _ := url.Parse("bolt://test.db") transport, _ := NewBoltTransport(u) @@ -60,9 +92,8 @@ func TestBoltTransportHistoryAndLive(t *testing.T) { }) } - s := newSubscriber("8") + s := newSubscriber("8", newTopicSelectorStore()) s.Topics = topics - s.RawTopics = topics go s.start() err := transport.AddSubscriber(s) @@ -148,7 +179,7 @@ func TestBoltTransportDoNotDispatchedUntilListen(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) go s.start() err := transport.AddSubscriber(s) @@ -184,9 +215,8 @@ func TestBoltTransportDispatch(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = []string{"https://example.com/foo"} - s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) @@ -209,9 +239,8 @@ func TestBoltTransportClosed(t *testing.T) { defer os.Remove("test.db") assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = []string{"https://example.com/foo"} - s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) @@ -237,12 +266,14 @@ func TestBoltCleanDisconnectedSubscribers(t *testing.T) { defer transport.Close() defer os.Remove("test.db") - s1 := newSubscriber("") + tss := newTopicSelectorStore() + + s1 := newSubscriber("", tss) go s1.start() err := transport.AddSubscriber(s1) require.Nil(t, err) - s2 := newSubscriber("") + s2 := newSubscriber("", tss) go s2.start() err = transport.AddSubscriber(s2) require.Nil(t, err) diff --git a/hub/config.go b/hub/config.go index 09406bdb..2de94c09 100644 --- a/hub/config.go +++ b/hub/config.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/viper" ) +// ErrInvalidConfig is returned when the configuration is invalid. var ErrInvalidConfig = errors.New("invalid config") // SetConfigDefaults sets defaults on a Viper instance. @@ -28,7 +29,6 @@ func SetConfigDefaults(v *viper.Viper) { v.SetDefault("use_forwarded_headers", false) v.SetDefault("demo", false) v.SetDefault("dispatch_subscriptions", false) - v.SetDefault("subscriptions_include_ip", false) v.SetDefault("metrics", false) v.SetDefault("metrics_login", "") v.SetDefault("metrics_password", "") diff --git a/hub/config_test.go b/hub/config_test.go index c15d7f62..ad3c4333 100644 --- a/hub/config_test.go +++ b/hub/config_test.go @@ -37,7 +37,7 @@ func TestSetFlags(t *testing.T) { fs := pflag.NewFlagSet("test", pflag.PanicOnError) SetFlags(fs, v) - assert.Subset(t, v.AllKeys(), []string{"cert_file", "compress", "demo", "jwt_algorithm", "transport_url", "acme_hosts", "acme_cert_dir", "subscriber_jwt_key", "log_format", "jwt_key", "allow_anonymous", "debug", "read_timeout", "publisher_jwt_algorithm", "write_timeout", "key_file", "use_forwarded_headers", "subscriber_jwt_algorithm", "addr", "publisher_jwt_key", "heartbeat_interval", "cors_allowed_origins", "publish_allowed_origins", "dispatch_subscriptions", "subscriptions_include_ip", "metrics", "metrics_login", "metrics_password", "dispatch_timeout"}) + assert.Subset(t, v.AllKeys(), []string{"cert_file", "compress", "demo", "jwt_algorithm", "transport_url", "acme_hosts", "acme_cert_dir", "subscriber_jwt_key", "log_format", "jwt_key", "allow_anonymous", "debug", "read_timeout", "publisher_jwt_algorithm", "write_timeout", "key_file", "use_forwarded_headers", "subscriber_jwt_algorithm", "addr", "publisher_jwt_key", "heartbeat_interval", "cors_allowed_origins", "publish_allowed_origins", "dispatch_subscriptions", "metrics", "metrics_login", "metrics_password", "dispatch_timeout"}) } func TestInitConfig(t *testing.T) { diff --git a/hub/hub.go b/hub/hub.go index 1725f551..ab898f35 100644 --- a/hub/hub.go +++ b/hub/hub.go @@ -3,32 +3,17 @@ package hub import ( "log" "net/http" - "sync" "github.com/spf13/viper" - "github.com/yosida95/uritemplate" ) -// uriTemplates caches uritemplate.Template to improve memory and CPU usage. -type uriTemplates struct { - sync.RWMutex - m map[string]*templateCache -} - -type templateCache struct { - // counter stores the number of subsribers currently using this topic - counter uint32 - // the uritemplate.Template instance, of nil if it's a raw string - template *uritemplate.Template -} - // Hub stores channels with clients currently subscribed and allows to dispatch updates. type Hub struct { - config *viper.Viper - transport Transport - server *http.Server - uriTemplates uriTemplates - metrics *Metrics + config *viper.Viper + transport Transport + server *http.Server + topicSelectorStore *topicSelectorStore + metrics *Metrics } // Stop stops disconnect all connected clients. @@ -56,7 +41,7 @@ func NewHubWithTransport(v *viper.Viper, t Transport) *Hub { v, t, nil, - uriTemplates{m: make(map[string]*templateCache)}, + newTopicSelectorStore(), NewMetrics(), } } diff --git a/hub/hub_test.go b/hub/hub_test.go index e13f4311..742a238a 100644 --- a/hub/hub_test.go +++ b/hub/hub_test.go @@ -87,16 +87,26 @@ func createDummyWithTransportAndConfig(t Transport, v *viper.Viper) *Hub { return NewHubWithTransport(v, t) } -func createDummyAuthorizedJWT(h *Hub, r role, targets []string) string { +func createDummyAuthorizedJWT(h *Hub, r role, topicSelectors []string) string { token := jwt.New(jwt.SigningMethodHS256) key := h.getJWTKey(r) switch r { - case publisherRole: - token.Claims = &claims{mercureClaim{Publish: targets}, jwt.StandardClaims{}} - - case subscriberRole: - token.Claims = &claims{mercureClaim{Subscribe: targets}, jwt.StandardClaims{}} + case rolePublisher: + token.Claims = &claims{mercureClaim{Publish: topicSelectors}, jwt.StandardClaims{}} + + case roleSubscriber: + var payload struct { + Foo string `json:"foo"` + } + payload.Foo = "bar" + token.Claims = &claims{ + mercureClaim{ + Subscribe: topicSelectors, + Payload: payload, + }, + jwt.StandardClaims{}, + } } tokenString, _ := token.SignedString(key) diff --git a/hub/log.go b/hub/log.go index 12877749..bf01221b 100644 --- a/hub/log.go +++ b/hub/log.go @@ -11,7 +11,7 @@ func addUpdateFields(f log.Fields, u *Update, debug bool) log.Fields { f["event_type"] = u.Type f["event_retry"] = u.Retry f["update_topics"] = u.Topics - f["update_targets"] = targetsMapToSlice(u.Targets) + f["update_private"] = u.Private if debug { f["update_data"] = u.Data @@ -29,18 +29,6 @@ func createFields(u *Update, s *Subscriber) log.Fields { return f } -func targetsMapToSlice(t map[string]struct{}) []string { - targets := make([]string, len(t)) - - var i int - for target := range t { - targets[i] = target - i++ - } - - return targets -} - // InitLogrus configures the global logger. func InitLogrus() { if viper.GetBool("debug") { diff --git a/hub/metrics_test.go b/hub/metrics_test.go index 42c6f864..fc75f677 100644 --- a/hub/metrics_test.go +++ b/hub/metrics_test.go @@ -11,13 +11,15 @@ import ( func TestNumberOfRunningSubscribers(t *testing.T) { m := NewMetrics() - s1 := newSubscriber("") + sst := newTopicSelectorStore() + + s1 := newSubscriber("", sst) s1.Topics = []string{"topic1", "topic2"} m.NewSubscriber(s1) assertGaugeLabelValue(t, 1.0, m.subscribers, "topic1") assertGaugeLabelValue(t, 1.0, m.subscribers, "topic2") - s2 := newSubscriber("") + s2 := newSubscriber("", sst) s2.Topics = []string{"topic2"} m.NewSubscriber(s2) assertGaugeLabelValue(t, 1.0, m.subscribers, "topic1") @@ -35,13 +37,15 @@ func TestNumberOfRunningSubscribers(t *testing.T) { func TestTotalNumberOfHandledSubscribers(t *testing.T) { m := NewMetrics() - s1 := newSubscriber("") + sst := newTopicSelectorStore() + + s1 := newSubscriber("", sst) s1.Topics = []string{"topic1", "topic2"} m.NewSubscriber(s1) assertCounterValue(t, 1.0, m.subscribersTotal, "topic1") assertCounterValue(t, 1.0, m.subscribersTotal, "topic2") - s2 := newSubscriber("") + s2 := newSubscriber("", sst) s2.Topics = []string{"topic2"} m.NewSubscriber(s2) assertCounterValue(t, 1.0, m.subscribersTotal, "topic1") diff --git a/hub/publish.go b/hub/publish.go index 00eb67ef..d7121837 100644 --- a/hub/publish.go +++ b/hub/publish.go @@ -1,8 +1,6 @@ package hub import ( - "errors" - "fmt" "io" "net/http" "strconv" @@ -10,11 +8,9 @@ import ( log "github.com/sirupsen/logrus" ) -var ErrTargetNotAuthorized = errors.New("target not authorized") - // PublishHandler allows publisher to broadcast updates to all subscribers. func (h *Hub) PublishHandler(w http.ResponseWriter, r *http.Request) { - claims, err := authorize(r, h.getJWTKey(publisherRole), h.getJWTAlgorithm(publisherRole), h.config.GetStringSlice("publish_allowed_origins")) + claims, err := authorize(r, h.getJWTKey(rolePublisher), h.getJWTAlgorithm(rolePublisher), h.config.GetStringSlice("publish_allowed_origins")) if err != nil || claims == nil || claims.Mercure.Publish == nil { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) log.WithFields(log.Fields{"remote_addr": r.RemoteAddr}).Info(err) @@ -32,18 +28,6 @@ func (h *Hub) PublishHandler(w http.ResponseWriter, r *http.Request) { return } - data := r.PostForm.Get("data") - if data == "" { - http.Error(w, "Missing \"data\" parameter", http.StatusBadRequest) - return - } - - targets, err := getAuthorizedTargets(claims, r.PostForm["target"]) - if err != nil { - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - return - } - var retry uint64 retryString := r.PostForm.Get("retry") if retryString != "" { @@ -54,7 +38,22 @@ func (h *Hub) PublishHandler(w http.ResponseWriter, r *http.Request) { } } - u := newUpdate(Event{data, r.PostForm.Get("id"), r.PostForm.Get("type"), retry}, topics, targets) + private := len(r.PostForm["private"]) != 0 + if private && !canDispatch(h.topicSelectorStore, topics, claims.Mercure.Publish) { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + u := newUpdate( + topics, + private, + Event{ + r.PostForm.Get("data"), + r.PostForm.Get("id"), + r.PostForm.Get("type"), + retry, + }, + ) // Broadcast the update if err := h.transport.Dispatch(u); err != nil { @@ -66,19 +65,3 @@ func (h *Hub) PublishHandler(w http.ResponseWriter, r *http.Request) { h.metrics.NewUpdate(u) } - -func getAuthorizedTargets(claims *claims, t []string) (map[string]struct{}, error) { - authorizedAlltargets, authorizedTargets := authorizedTargets(claims, true) - targets := make(map[string]struct{}, len(t)) - for _, t := range t { - if !authorizedAlltargets { - _, ok := authorizedTargets[t] - if !ok { - return nil, fmt.Errorf("%q: %w", t, ErrTargetNotAuthorized) - } - } - targets[t] = struct{}{} - } - - return targets, nil -} diff --git a/hub/publish_test.go b/hub/publish_test.go index abe42eb7..ac5998e0 100644 --- a/hub/publish_test.go +++ b/hub/publish_test.go @@ -62,7 +62,7 @@ func TestPublishBadContentType(t *testing.T) { hub := createDummy() req := httptest.NewRequest("POST", defaultHubURL, nil) - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, []string{})) req.Header.Add("Content-Type", "text/plain; boundary=") w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -77,7 +77,7 @@ func TestPublishNoTopic(t *testing.T) { hub := createDummy() req := httptest.NewRequest("POST", defaultHubURL, nil) - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, []string{})) w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -96,7 +96,7 @@ func TestPublishNoData(t *testing.T) { req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, []string{"*"})) w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -104,8 +104,7 @@ func TestPublishNoData(t *testing.T) { resp := w.Result() defer resp.Body.Close() - assert.Equal(t, http.StatusBadRequest, resp.StatusCode) - assert.Equal(t, "Missing \"data\" parameter\n", w.Body.String()) + assert.Equal(t, http.StatusOK, resp.StatusCode) } func TestPublishInvalidRetry(t *testing.T) { @@ -118,7 +117,7 @@ func TestPublishInvalidRetry(t *testing.T) { req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, []string{})) w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -130,17 +129,17 @@ func TestPublishInvalidRetry(t *testing.T) { assert.Equal(t, "Invalid \"retry\" parameter\n", w.Body.String()) } -func TestPublishNotAuthorizedTarget(t *testing.T) { +func TestPublishNotAuthorizedTopicSelector(t *testing.T) { hub := createDummy() form := url.Values{} form.Add("topic", "http://example.com/books/1") form.Add("data", "foo") - form.Add("target", "not-allowed") + form.Add("private", "on") req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{"foo"})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, []string{"foo"})) w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -155,10 +154,9 @@ func TestPublishOK(t *testing.T) { hub := createDummy() defer hub.Stop() - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = []string{"http://example.com/books/1"} - s.RawTopics = s.Topics - s.Targets = map[string]struct{}{"foo": {}} + s.Claims = &claims{Mercure: mercureClaim{Subscribe: s.Topics}} go s.start() err := hub.transport.AddSubscriber(s) @@ -172,22 +170,20 @@ func TestPublishOK(t *testing.T) { assert.True(t, ok) require.NotNil(t, u) assert.Equal(t, "id", u.ID) - assert.Equal(t, []string{"http://example.com/books/1"}, u.Topics) + assert.Equal(t, s.Topics, u.Topics) assert.Equal(t, "Hello!", u.Data) - assert.Equal(t, struct{}{}, u.Targets["foo"]) - assert.Equal(t, struct{}{}, u.Targets["bar"]) + assert.True(t, u.Private) }(&wg) form := url.Values{} form.Add("id", "id") form.Add("topic", "http://example.com/books/1") form.Add("data", "Hello!") - form.Add("target", "foo") - form.Add("target", "bar") + form.Add("private", "on") req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{"foo", "bar"})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, s.Topics)) w := httptest.NewRecorder() hub.PublishHandler(w, req) @@ -206,10 +202,8 @@ func TestPublishGenerateUUID(t *testing.T) { h := createDummy() defer h.Stop() - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = []string{"http://example.com/books/1"} - s.RawTopics = s.Topics - s.Targets = map[string]struct{}{"foo": {}} go s.start() h.transport.AddSubscriber(s) @@ -231,8 +225,7 @@ func TestPublishGenerateUUID(t *testing.T) { req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - //req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, publisherRole, []string{})}) - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, rolePublisher, []string{})) w := httptest.NewRecorder() h.PublishHandler(w, req) @@ -264,12 +257,11 @@ func TestPublishWithErrorInTransport(t *testing.T) { form.Add("id", "id") form.Add("topic", "http://example.com/books/1") form.Add("data", "Hello!") - form.Add("target", "foo") - form.Add("target", "bar") + form.Add("private", "on") req := httptest.NewRequest("POST", defaultHubURL, strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, publisherRole, []string{"foo", "bar"})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, rolePublisher, []string{"foo", "http://example.com/books/1"})) w := httptest.NewRecorder() hub.PublishHandler(w, req) diff --git a/hub/server_test.go b/hub/server_test.go index df37ac0d..a3ef7d13 100644 --- a/hub/server_test.go +++ b/hub/server_test.go @@ -43,7 +43,7 @@ func TestForwardedHeaders(t *testing.T) { req, _ := http.NewRequest("POST", testURL, strings.NewReader(body.Encode())) req.Header.Add("X-Forwarded-For", "192.0.2.1") req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, rolePublisher, []string{})) resp2, err := client.Do(req) require.Nil(t, err) @@ -159,7 +159,7 @@ func TestServe(t *testing.T) { body := url.Values{"topic": {"http://example.com/foo/1", "http://example.com/alt/1"}, "data": {"hello"}, "id": {"first"}} req, _ := http.NewRequest("POST", testURL, strings.NewReader(body.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, rolePublisher, []string{})) resp2, err := client.Do(req) require.Nil(t, err) @@ -227,7 +227,7 @@ func TestClientClosesThenReconnects(t *testing.T) { req, err := http.NewRequest("POST", testURL, strings.NewReader(body.Encode())) require.Nil(t, err) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(h, rolePublisher, []string{})) resp, err := client.Do(req) require.Nil(t, err) @@ -437,7 +437,7 @@ func (s *testServer) newSubscriber(topic string, keepAlive bool) { func (s *testServer) publish(body url.Values) { req, _ := http.NewRequest("POST", testURL, strings.NewReader(body.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(s.h, publisherRole, []string{})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(s.h, rolePublisher, []string{})) resp, err := s.client.Do(req) require.Nil(s.t, err) diff --git a/hub/subscribe.go b/hub/subscribe.go index c31d7881..b05956bb 100644 --- a/hub/subscribe.go +++ b/hub/subscribe.go @@ -4,23 +4,19 @@ import ( "encoding/json" "fmt" "io" - "net" "net/http" "net/url" - "strings" "time" log "github.com/sirupsen/logrus" - "github.com/yosida95/uritemplate" ) type subscription struct { - ID string `json:"@id"` - Type string `json:"@type"` - Topic string `json:"topic"` - Active bool `json:"active"` - mercureClaim - Address string `json:"address,omitempty"` + ID string `json:"@id"` + Type string `json:"@type"` + Topic string `json:"topic"` + Active bool `json:"active"` + Payload interface{} `json:"payload,omitempty"` } // SubscribeHandler creates a keep alive connection and sends the events to the subscribers. @@ -61,6 +57,10 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { } timer.Reset(hearthbeatInterval) case update := <-s.Receive(): + if update.PreviousID != "" { + w.Header().Set("Last-Event-ID", update.PreviousID) + } + if !h.write(w, s, newSerializedUpdate(update).event) { return } @@ -77,14 +77,14 @@ func (h *Hub) SubscribeHandler(w http.ResponseWriter, r *http.Request) { // registerSubscriber initializes the connection. func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request, debug bool) *Subscriber { - s := newSubscriber(retrieveLastEventID(r)) + s := newSubscriber(retrieveLastEventID(r), h.topicSelectorStore) s.Debug = debug s.LogFields["remote_addr"] = r.RemoteAddr - claims, err := authorize(r, h.getJWTKey(subscriberRole), h.getJWTAlgorithm(subscriberRole), nil) + claims, err := authorize(r, h.getJWTKey(roleSubscriber), h.getJWTAlgorithm(roleSubscriber), nil) if claims != nil { s.Claims = claims - s.LogFields["subscriber_targets"] = claims.Mercure.Subscribe + s.LogFields["subscriber_topic_selectors"] = claims.Mercure.Subscribe } if err != nil || (claims == nil && !h.config.GetBool("allow_anonymous")) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) @@ -99,16 +99,11 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request, debug b } s.LogFields["subscriber_topics"] = s.Topics - s.RawTopics, s.TemplateTopics = h.parseTopics(s.Topics) s.EscapedTopics = escapeTopics(s.Topics) - s.AllTargets, s.Targets = authorizedTargets(claims, false) s.RemoteAddr = r.RemoteAddr go s.start() - if h.config.GetBool("subscriptions_include_ip") { - s.RemoteHost, _, _ = net.SplitHostPort(r.RemoteAddr) - } h.dispatchSubscriptionUpdate(s, true) if err := h.transport.AddSubscriber(s); err != nil { http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) @@ -116,7 +111,7 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request, debug b log.WithFields(s.LogFields).Error(err) return nil } - sendHeaders(w) + sendHeaders(w, s.LastEventID == "") log.WithFields(s.LogFields).Info("New subscriber") h.metrics.NewSubscriber(s) @@ -124,42 +119,8 @@ func (h *Hub) registerSubscriber(w http.ResponseWriter, r *http.Request, debug b return s } -func (h *Hub) parseTopics(topics []string) (rawTopics []string, templateTopics []*uritemplate.Template) { - rawTopics = make([]string, 0, len(topics)) - templateTopics = make([]*uritemplate.Template, 0, len(topics)) - for _, topic := range topics { - if tpl := h.getURITemplate(topic); tpl == nil { - rawTopics = append(rawTopics, topic) - } else { - templateTopics = append(templateTopics, tpl) - } - } - - return rawTopics, templateTopics -} - -// getURITemplate retrieves or creates the uritemplate.Template associated with this topic, or nil if it's not a template. -func (h *Hub) getURITemplate(topic string) *uritemplate.Template { - var tpl *uritemplate.Template - h.uriTemplates.Lock() - defer h.uriTemplates.Unlock() - if tplCache, ok := h.uriTemplates.m[topic]; ok { - tpl = tplCache.template - tplCache.counter++ - - return tpl - } - if strings.Contains(topic, "{") { // If it's definitely not an URI template, skip to save some resources - tpl, _ = uritemplate.New(topic) // Returns nil in case of error, will be considered as a raw string - } - - h.uriTemplates.m[topic] = &templateCache{1, tpl} - - return tpl -} - // sendHeaders sends correct HTTP headers to create a keep-alive connection. -func sendHeaders(w http.ResponseWriter) { +func sendHeaders(w http.ResponseWriter, flush bool) { // Keep alive, useful only for HTTP 1 clients https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Keep-Alive w.Header().Set("Connection", "keep-alive") @@ -174,10 +135,12 @@ func sendHeaders(w http.ResponseWriter) { // NGINX support https://www.nginx.com/resources/wiki/start/topics/examples/x-accel/#x-accel-buffering w.Header().Set("X-Accel-Buffering", "no") - // Write a comment in the body - // Go currently doesn't provide a better way to flush the headers - fmt.Fprint(w, ":\n") - w.(http.Flusher).Flush() + if flush { + // Write a comment in the body + // Go currently doesn't provide a better way to flush the headers + fmt.Fprint(w, ":\n") + w.(http.Flusher).Flush() + } } // retrieveLastEventID extracts the Last-Event-ID from the corresponding HTTP header with a fallback on the query parameter. @@ -222,24 +185,6 @@ func (h *Hub) shutdown(s *Subscriber) { h.dispatchSubscriptionUpdate(s, false) log.WithFields(s.LogFields).Info("Subscriber disconnected") h.metrics.SubscriberDisconnect(s) - - // Remove unused uritemplate.Template instances from memory. - keys := make([]string, 0, len(s.RawTopics)+len(s.TemplateTopics)) - copy(s.RawTopics, keys) - for _, uriTemplate := range s.TemplateTopics { - keys = append(keys, uriTemplate.Raw()) - } - - h.uriTemplates.Lock() - for _, key := range keys { - counter := h.uriTemplates.m[key].counter - if counter == 0 { - delete(h.uriTemplates.m, key) - } else { - h.uriTemplates.m[key].counter = counter - 1 - } - } - h.uriTemplates.Unlock() } func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { @@ -249,21 +194,14 @@ func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { for k, topic := range s.Topics { connection := &subscription{ - ID: "https://mercure.rocks/subscriptions/" + s.EscapedTopics[k] + "/" + s.EscapedID, - Type: "https://mercure.rocks/Subscription", - Topic: topic, - Active: active, - Address: s.RemoteHost, + ID: "/.well-known/mercure/subscriptions/" + s.EscapedID + "/" + s.EscapedTopics[k], + Type: "https://mercure.rocks/Subscription", + Topic: topic, + Active: active, } - if s.Claims != nil { - connection.mercureClaim = s.Claims.Mercure - } - if s.Claims == nil || connection.mercureClaim.Publish == nil { - connection.mercureClaim.Publish = []string{} - } - if s.Claims == nil || connection.mercureClaim.Subscribe == nil { - connection.mercureClaim.Subscribe = []string{} + if s.Claims != nil && s.Claims.Mercure.Payload != nil { + connection.Payload = s.Claims.Mercure.Payload } json, err := json.MarshalIndent(connection, "", " ") @@ -271,12 +209,7 @@ func (h *Hub) dispatchSubscriptionUpdate(s *Subscriber, active bool) { panic(err) } - u := newUpdate( - Event{Data: string(json)}, - []string{connection.ID}, - map[string]struct{}{"https://mercure.rocks/targets/subscriptions": {}, "https://mercure.rocks/targets/subscriptions/" + s.EscapedTopics[k]: {}}, - ) - + u := newUpdate([]string{connection.ID}, true, Event{Data: string(json)}) h.transport.Dispatch(u) } } diff --git a/hub/subscribe_test.go b/hub/subscribe_test.go index 4a363f75..4bc21693 100644 --- a/hub/subscribe_test.go +++ b/hub/subscribe_test.go @@ -33,6 +33,7 @@ func (m *responseWriterMock) WriteHeader(statusCode int) { } type responseTester struct { + header http.Header body string expectedStatusCode int expectedBody string @@ -41,7 +42,11 @@ type responseTester struct { } func (rt *responseTester) Header() http.Header { - return http.Header{} + if rt.header == nil { + return http.Header{} + } + + return rt.header } func (rt *responseTester) Write(buf []byte) (int, error) { @@ -287,7 +292,7 @@ func TestUnsubscribe(t *testing.T) { wg.Wait() } -func TestSubscribeTarget(t *testing.T) { +func TestSubscribePrivate(t *testing.T) { hub := createDummy() hub.config.Set("debug", true) s, _ := hub.transport.(*LocalTransport) @@ -303,19 +308,19 @@ func TestSubscribeTarget(t *testing.T) { } hub.transport.Dispatch(&Update{ - Targets: map[string]struct{}{"baz": {}}, Topics: []string{"http://example.com/reviews/21"}, Event: Event{Data: "Foo", ID: "a"}, + Private: true, }) hub.transport.Dispatch(&Update{ - Targets: map[string]struct{}{}, Topics: []string{"http://example.com/reviews/22"}, Event: Event{Data: "Hello World", ID: "b", Type: "test"}, + Private: true, }) hub.transport.Dispatch(&Update{ - Targets: map[string]struct{}{"hello": {}, "bar": {}}, Topics: []string{"http://example.com/reviews/23"}, Event: Event{Data: "Great", ID: "c", Retry: 1}, + Private: true, }) return } @@ -323,7 +328,7 @@ func TestSubscribeTarget(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) req := httptest.NewRequest("GET", defaultHubURL+"?topic=http://example.com/reviews/{id}", nil).WithContext(ctx) - req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, subscriberRole, []string{"foo", "bar"})}) + req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, roleSubscriber, []string{"http://example.com/reviews/22", "http://example.com/reviews/23"})}) w := &responseTester{ expectedStatusCode: http.StatusOK, @@ -339,7 +344,6 @@ func TestSubscribeTarget(t *testing.T) { func TestSubscriptionEvents(t *testing.T) { hub := createDummy() hub.config.Set("dispatch_subscriptions", true) - hub.config.Set("subscriptions_include_ip", true) var wg sync.WaitGroup ctx1, cancel1 := context.WithCancel(context.Background()) @@ -348,8 +352,8 @@ func TestSubscriptionEvents(t *testing.T) { go func() { // Authorized to receive connection events defer wg.Done() - req := httptest.NewRequest("GET", defaultHubURL+"?topic=https://mercure.rocks/subscriptions/{topic}/{connectionID}", nil).WithContext(ctx1) - req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, subscriberRole, []string{"https://mercure.rocks/targets/subscriptions"})}) + req := httptest.NewRequest("GET", defaultHubURL+"?topic=/.well-known/mercure/subscriptions/{subscriptionID}/{topic}", nil).WithContext(ctx1) + req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, roleSubscriber, []string{"/.well-known/mercure/subscriptions/{subscriptionID}/{topic}"})}) w := httptest.NewRecorder() hub.SubscribeHandler(w, req) @@ -359,21 +363,21 @@ func TestSubscriptionEvents(t *testing.T) { assert.Equal(t, http.StatusOK, resp.StatusCode) bodyContent := string(body) - assert.Contains(t, bodyContent, `data: "@id": "https://mercure.rocks/subscriptions/https%3A%2F%2Fexample.com/`) + assert.Contains(t, bodyContent, `data: "@id": "/.well-known/mercure/subscriptions/`) + assert.Contains(t, bodyContent, `/https%3A%2F%2Fexample.com`) assert.Contains(t, bodyContent, `data: "@type": "https://mercure.rocks/Subscription",`) assert.Contains(t, bodyContent, `data: "topic": "https://example.com",`) - assert.Contains(t, bodyContent, `data: "publish": [],`) - assert.Contains(t, bodyContent, `data: "subscribe": []`) assert.Contains(t, bodyContent, `data: "active": true,`) assert.Contains(t, bodyContent, `data: "active": false,`) - assert.Contains(t, bodyContent, `data: "address": "`) + assert.Contains(t, bodyContent, `data: "payload": {`) + assert.Contains(t, bodyContent, `data: "foo": "bar"`) }() go func() { // Not authorized to receive connection events defer wg.Done() - req := httptest.NewRequest("GET", defaultHubURL+"?topic=https://mercure.rocks/subscriptions/{topic}/{connectionID}", nil).WithContext(ctx2) - req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, subscriberRole, []string{})}) + req := httptest.NewRequest("GET", defaultHubURL+"?topic=/.well-known/mercure/subscriptions/{subscriptionID}/{topic}", nil).WithContext(ctx2) + req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, roleSubscriber, []string{})}) w := httptest.NewRecorder() hub.SubscribeHandler(w, req) @@ -401,7 +405,7 @@ func TestSubscriptionEvents(t *testing.T) { ctx, cancelRequest2 := context.WithCancel(context.Background()) req := httptest.NewRequest("GET", defaultHubURL+"?topic=https://example.com", nil).WithContext(ctx) - req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, subscriberRole, []string{})}) + req.AddCookie(&http.Cookie{Name: "mercureAuthorization", Value: createDummyAuthorizedJWT(hub, roleSubscriber, []string{})}) w := &responseTester{ expectedStatusCode: http.StatusOK, @@ -419,7 +423,7 @@ func TestSubscriptionEvents(t *testing.T) { hub.Stop() } -func TestSubscribeAllTargets(t *testing.T) { +func TestSubscribeAll(t *testing.T) { hub := createDummy() s, _ := hub.transport.(*LocalTransport) @@ -434,14 +438,14 @@ func TestSubscribeAllTargets(t *testing.T) { } hub.transport.Dispatch(&Update{ - Targets: map[string]struct{}{"foo": {}}, Topics: []string{"http://example.com/reviews/21"}, Event: Event{Data: "Foo", ID: "a"}, + Private: true, }) hub.transport.Dispatch(&Update{ - Targets: map[string]struct{}{"bar": {}}, Topics: []string{"http://example.com/reviews/22"}, Event: Event{Data: "Hello World", ID: "b", Type: "test"}, + Private: true, }) return @@ -450,7 +454,7 @@ func TestSubscribeAllTargets(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) req := httptest.NewRequest("GET", defaultHubURL+"?topic=http://example.com/reviews/{id}", nil).WithContext(ctx) - req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, subscriberRole, []string{"random", "*"})) + req.Header.Add("Authorization", "Bearer "+createDummyAuthorizedJWT(hub, roleSubscriber, []string{"random", "*"})) w := &responseTester{ expectedStatusCode: http.StatusOK, @@ -497,7 +501,7 @@ func TestSendMissedEvents(t *testing.T) { w := &responseTester{ expectedStatusCode: http.StatusOK, - expectedBody: ":\nid: b\ndata: d2\n\n", + expectedBody: "id: b\ndata: d2\n\n", t: t, cancel: cancel, } @@ -514,12 +518,79 @@ func TestSendMissedEvents(t *testing.T) { w := &responseTester{ expectedStatusCode: http.StatusOK, - expectedBody: ":\nid: b\ndata: d2\n\n", + expectedBody: "id: b\ndata: d2\n\n", + t: t, + cancel: cancel, + } + + hub.SubscribeHandler(w, req) + }() + + wg.Wait() + hub.Stop() +} + +func TestSendAllEvents(t *testing.T) { + u, _ := url.Parse("bolt://test.db") + transport, _ := NewBoltTransport(u) + defer transport.Close() + defer os.Remove("test.db") + + hub := createDummyWithTransportAndConfig(transport, viper.New()) + + transport.Dispatch(&Update{ + Topics: []string{"http://example.com/foos/a"}, + Event: Event{ + ID: "a", + Data: "d1", + }, + }) + transport.Dispatch(&Update{ + Topics: []string{"http://example.com/foos/b"}, + Event: Event{ + ID: "b", + Data: "d2", + }, + }) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("GET", defaultHubURL+"?topic=http://example.com/foos/{id}&Last-Event-ID=-1", nil).WithContext(ctx) + + w := &responseTester{ + header: http.Header{}, + expectedStatusCode: http.StatusOK, + expectedBody: "id: a\ndata: d1\n\nid: b\ndata: d2\n\n", + t: t, + cancel: cancel, + } + + hub.SubscribeHandler(w, req) + assert.Equal(t, "-1", w.Header().Get("Last-Event-ID")) + }() + + go func() { + defer wg.Done() + + ctx, cancel := context.WithCancel(context.Background()) + req := httptest.NewRequest("GET", defaultHubURL+"?topic=http://example.com/foos/{id}", nil).WithContext(ctx) + req.Header.Add("Last-Event-ID", "-1") + + w := &responseTester{ + header: http.Header{}, + expectedStatusCode: http.StatusOK, + expectedBody: "id: a\ndata: d1\n\nid: b\ndata: d2\n\n", t: t, cancel: cancel, } hub.SubscribeHandler(w, req) + assert.Equal(t, "-1", w.Header().Get("Last-Event-ID")) }() wg.Wait() diff --git a/hub/subscriber.go b/hub/subscriber.go index cae93371..25e3ac83 100644 --- a/hub/subscriber.go +++ b/hub/subscriber.go @@ -5,7 +5,6 @@ import ( "github.com/gofrs/uuid" log "github.com/sirupsen/logrus" - "github.com/yosida95/uritemplate" ) type updateSource struct { @@ -15,29 +14,24 @@ type updateSource struct { // Subscriber represents a client subscribed to a list of topics. type Subscriber struct { - ID string - EscapedID string - Claims *claims - Targets map[string]struct{} - Topics []string - EscapedTopics []string - RawTopics []string - TemplateTopics []*uritemplate.Template - LastEventID string - RemoteAddr string - RemoteHost string - LogFields log.Fields - AllTargets bool - Debug bool - - out chan *Update - disconnected chan struct{} - matchCache map[string]bool - history updateSource - live updateSource + ID string + EscapedID string + Claims *claims + Topics []string + EscapedTopics []string + LastEventID string + RemoteAddr string + LogFields log.Fields + Debug bool + + out chan *Update + disconnected chan struct{} + history updateSource + live updateSource + topicSelectorStore *topicSelectorStore } -func newSubscriber(lastEventID string) *Subscriber { +func newSubscriber(lastEventID string, uriTemplates *topicSelectorStore) *Subscriber { id := "urn:uuid:" + uuid.Must(uuid.NewV4()).String() s := &Subscriber{ ID: id, @@ -47,11 +41,11 @@ func newSubscriber(lastEventID string) *Subscriber { "subscriber_id": id, "last_event_id": lastEventID, }, - history: updateSource{}, - live: updateSource{in: make(chan *Update)}, - out: make(chan *Update), - disconnected: make(chan struct{}), - matchCache: make(map[string]bool), + history: updateSource{}, + live: updateSource{in: make(chan *Update)}, + out: make(chan *Update), + disconnected: make(chan struct{}), + topicSelectorStore: uriTemplates, } if lastEventID != "" { @@ -64,6 +58,7 @@ func newSubscriber(lastEventID string) *Subscriber { // start stores incoming updates in an history and a live buffer and dispatch them. // Updates coming from the history are always dispatched first. func (s *Subscriber) start() { + defer s.cleanup() for { select { case <-s.disconnected: @@ -91,6 +86,13 @@ func (s *Subscriber) start() { } } +func (s *Subscriber) cleanup() { + s.topicSelectorStore.cleanup(s.Topics) + if s.Claims != nil && s.Claims.Mercure.Subscribe != nil { + s.topicSelectorStore.cleanup(s.Claims.Mercure.Subscribe) + } +} + // outChan returns the out channel if buffers aren't empty, or nil to block. func (s *Subscriber) outChan() chan<- *Update { if len(s.live.buffer) > 0 || len(s.history.buffer) > 0 { @@ -163,62 +165,15 @@ func (s *Subscriber) Disconnected() <-chan struct{} { // CanDispatch checks if an update can be dispatched to this subsriber. func (s *Subscriber) CanDispatch(u *Update) bool { - if !s.IsAuthorized(u) { - log.WithFields(createFields(u, s)).Debug("Subscriber not authorized to receive this update (no targets matching)") + if !canReceive(s.topicSelectorStore, u.Topics, s.Topics) { + log.WithFields(createFields(u, s)).Debug("Subscriber has not subscribed to this update") return false } - if !s.IsSubscribed(u) { - log.WithFields(createFields(u, s)).Debug("Subscriber has not subscribed to this update (no topics matching)") + if u.Private && (s.Claims == nil || s.Claims.Mercure.Subscribe == nil || !canReceive(s.topicSelectorStore, u.Topics, s.Claims.Mercure.Subscribe)) { + log.WithFields(createFields(u, s)).Debug("Subscriber not authorized to receive this update") return false } return true } - -// IsAuthorized checks if the subscriber can access to at least one of the update's intended targets. -// Don't forget to also call IsSubscribed. -func (s *Subscriber) IsAuthorized(u *Update) bool { - if s.AllTargets || len(u.Targets) == 0 { - return true - } - - for t := range s.Targets { - if _, ok := u.Targets[t]; ok { - return true - } - } - - return false -} - -// IsSubscribed checks if the subscriber has subscribed to this update. -// Don't forget to also call IsAuthorized. -func (s *Subscriber) IsSubscribed(u *Update) bool { - for _, ut := range u.Topics { - if match, ok := s.matchCache[ut]; ok { - if match { - return true - } - continue - } - - for _, rt := range s.RawTopics { - if ut == rt { - s.matchCache[ut] = true - return true - } - } - - for _, tt := range s.TemplateTopics { - if tt.Match(ut) != nil { - s.matchCache[ut] = true - return true - } - } - - s.matchCache[ut] = false - } - - return false -} diff --git a/hub/subscriber_test.go b/hub/subscriber_test.go index 835119a1..eacaf05e 100644 --- a/hub/subscriber_test.go +++ b/hub/subscriber_test.go @@ -7,25 +7,9 @@ import ( "github.com/stretchr/testify/assert" ) -func TestIsSubscribed(t *testing.T) { - s := newSubscriber("") - s.Topics = []string{"foo", "bar"} - s.RawTopics = s.Topics - - assert.Len(t, s.matchCache, 0) - assert.False(t, s.IsSubscribed(&Update{Topics: []string{"baz", "bat"}})) - assert.True(t, s.IsSubscribed(&Update{Topics: []string{"baz", "bar"}})) - assert.Len(t, s.matchCache, 3) - - // assert cache is used - assert.True(t, s.IsSubscribed(&Update{Topics: []string{"bar", "qux"}})) - assert.Len(t, s.matchCache, 3) -} - func TestDispatch(t *testing.T) { - s := newSubscriber("1") + s := newSubscriber("1", newTopicSelectorStore()) s.Topics = []string{"http://example.com"} - s.RawTopics = s.Topics go s.start() defer s.Disconnect() @@ -44,7 +28,7 @@ func TestDispatch(t *testing.T) { } func TestDisconnect(t *testing.T) { - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Disconnect() // can be called two times without crashing s.Disconnect() diff --git a/hub/topic_selector.go b/hub/topic_selector.go new file mode 100644 index 00000000..3e000814 --- /dev/null +++ b/hub/topic_selector.go @@ -0,0 +1,99 @@ +package hub + +import ( + "strings" + "sync" + + "github.com/yosida95/uritemplate" +) + +type selector struct { + sync.RWMutex + // counter stores the number of subsribers currently using this topic + counter uint32 + // the uritemplate.Template instance, of nil if it's a raw string + template *uritemplate.Template + matchCache map[string]bool +} + +// topicSelectorStore caches uritemplate.Template to improve memory and CPU usage. +type topicSelectorStore struct { + sync.RWMutex + m map[string]*selector +} + +func newTopicSelectorStore() *topicSelectorStore { + return &topicSelectorStore{m: make(map[string]*selector)} +} + +func (tss *topicSelectorStore) match(topic, topicSelector string, addToCache bool) bool { + // Always do an exact matching comparison first + // Also check if the topic selector is the reserved keyword * + if topicSelector == "*" || topic == topicSelector { + return true + } + + templateStore := tss.getTemplateStore(topicSelector, addToCache) + templateStore.RLock() + match, ok := templateStore.matchCache[topic] + templateStore.RUnlock() + if ok { + return match + } + + match = templateStore.template != nil && templateStore.template.Match(topic) != nil + templateStore.Lock() + templateStore.matchCache[topic] = match + templateStore.Unlock() + + return match +} + +// getTemplateStore retrieves or creates the uritemplate.Template associated with this topic, or nil if it's not a template. +func (tss *topicSelectorStore) getTemplateStore(topicSelector string, addToCache bool) *selector { + if addToCache { + tss.Lock() + defer tss.Unlock() + } else { + tss.RLock() + } + + s, ok := tss.m[topicSelector] + if !addToCache { + tss.RUnlock() + } + if ok { + if addToCache { + s.counter++ + } + + return s + } + + s = &selector{matchCache: make(map[string]bool)} + if strings.Contains(topicSelector, "{") { // If it's definitely not an URI template, skip to save some resources + s.template, _ = uritemplate.New(topicSelector) // Returns nil in case of error, will be considered as a raw string + } + + if addToCache { + tss.m[topicSelector] = s + } + + return s +} + +// Remove unused uritemplate.Template instances from memory. +func (tss *topicSelectorStore) cleanup(topics []string) { + tss.Lock() + defer tss.Unlock() + for _, topic := range topics { + if tc, ok := tss.m[topic]; ok { + if tc.counter == 0 { + delete(tss.m, topic) + continue + } + + tc.counter-- + } + } +} diff --git a/hub/topic_selector_test.go b/hub/topic_selector_test.go new file mode 100644 index 00000000..8ebbc53b --- /dev/null +++ b/hub/topic_selector_test.go @@ -0,0 +1,32 @@ +package hub + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMatch(t *testing.T) { + tss := newTopicSelectorStore() + + assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/{foo}/bar", false)) + assert.Empty(t, tss.m) + assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/{foo}/bar", true)) + assert.False(t, tss.match("https://example.com/foo/bar/baz", "https://example.com/{foo}/bar", true)) + assert.NotNil(t, tss.m["https://example.com/{foo}/bar"].template) + assert.True(t, tss.m["https://example.com/{foo}/bar"].matchCache["https://example.com/foo/bar"]) + assert.False(t, tss.m["https://example.com/{foo}/bar"].matchCache["https://example.com/foo/bar/baz"]) + assert.Equal(t, tss.m["https://example.com/{foo}/bar"].counter, uint32(1)) + + assert.True(t, tss.match("https://example.com/kevin/dunglas", "https://example.com/{fistname}/{lastname}", true)) + assert.True(t, tss.match("https://example.com/foo/bar", "*", true)) + assert.True(t, tss.match("https://example.com/foo/bar", "https://example.com/foo/bar", true)) + assert.True(t, tss.match("foo", "foo", true)) + assert.False(t, tss.match("foo", "bar", true)) + + tss.cleanup([]string{"https://example.com/{foo}/bar", "https://example.com/{fistname}/{lastname}", "bar"}) + assert.Len(t, tss.m, 1) + + tss.cleanup([]string{"https://example.com/{foo}/bar", "https://example.com/{fistname}/{lastname}"}) + assert.Empty(t, tss.m) +} diff --git a/hub/transport.go b/hub/transport.go index 8514fe46..bd3872fe 100644 --- a/hub/transport.go +++ b/hub/transport.go @@ -22,7 +22,7 @@ type Transport interface { } var ( - // ErrInvalidTransportDSN is returned when the Transport's DSN is invalid + // ErrInvalidTransportDSN is returned when the Transport's DSN is invalid. ErrInvalidTransportDSN = errors.New("invalid transport DSN") // ErrClosedTransport is returned by the Transport's Dispatch and AddSubscriber methods after a call to Close. ErrClosedTransport = errors.New("hub: read/write on closed Transport") diff --git a/hub/transport_test.go b/hub/transport_test.go index 9d3b0e76..b568a038 100644 --- a/hub/transport_test.go +++ b/hub/transport_test.go @@ -15,16 +15,12 @@ func TestLocalTransportDoNotDispatchUntilListen(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - u := &Update{ - Topics: []string{"http://example.com/books/1"}, - } + u := &Update{Topics: []string{"http://example.com/books/1"}} err := transport.Dispatch(u) require.Nil(t, err) - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = u.Topics - s.RawTopics = u.Topics - s.Targets = map[string]struct{}{"foo": {}} go s.start() err = transport.AddSubscriber(s) @@ -57,9 +53,8 @@ func TestLocalTransportDispatch(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = []string{"http://example.com/foo"} - s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) @@ -79,14 +74,16 @@ func TestLocalTransportClosed(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber("") + tss := newTopicSelectorStore() + + s := newSubscriber("", tss) err := transport.AddSubscriber(s) require.Nil(t, err) err = transport.Close() assert.Nil(t, err) - err = transport.AddSubscriber(newSubscriber("")) + err = transport.AddSubscriber(newSubscriber("", tss)) assert.Equal(t, err, ErrClosedTransport) err = transport.Dispatch(&Update{}) @@ -100,13 +97,15 @@ func TestLiveCleanDisconnectedSubscribers(t *testing.T) { transport := NewLocalTransport() defer transport.Close() - s1 := newSubscriber("") + tss := newTopicSelectorStore() + + s1 := newSubscriber("", tss) go s1.start() err := transport.AddSubscriber(s1) require.Nil(t, err) - s2 := newSubscriber("") + s2 := newSubscriber("", tss) go s2.start() err = transport.AddSubscriber(s2) @@ -132,9 +131,8 @@ func TestLiveReading(t *testing.T) { defer transport.Close() assert.Implements(t, (*Transport)(nil), transport) - s := newSubscriber("") + s := newSubscriber("", newTopicSelectorStore()) s.Topics = []string{"https://example.com"} - s.RawTopics = s.Topics go s.start() err := transport.AddSubscriber(s) diff --git a/hub/update.go b/hub/update.go index 38226882..115134e9 100644 --- a/hub/update.go +++ b/hub/update.go @@ -4,13 +4,17 @@ import "github.com/gofrs/uuid" // Update represents an update to send to subscribers. type Update struct { - // The target audience. - Targets map[string]struct{} - // The topics' Internationalized Resource Identifier (RFC3987) (will most likely be URLs). // The first one is the canonical IRI, while next ones are alternate IRIs. Topics []string + // Private updates can only be dispatched to subscribers authorized to receive them. + Private bool + + // PreviousID contains the ID of the previous update + // This value must be sent only if the request Last-Event-ID cannot be found, and only on the first update available + PreviousID string + // The Server-Sent Event to send. Event } @@ -20,11 +24,11 @@ type serializedUpdate struct { event string } -func newUpdate(event Event, topics []string, targets map[string]struct{}) *Update { +func newUpdate(topics []string, private bool, event Event) *Update { u := &Update{ - Event: event, Topics: topics, - Targets: targets, + Private: private, + Event: event, } if u.ID == "" { u.ID = "urn:uuid:" + uuid.Must(uuid.NewV4()).String() diff --git a/hub/update_test.go b/hub/update_test.go new file mode 100644 index 00000000..cb0d274d --- /dev/null +++ b/hub/update_test.go @@ -0,0 +1,22 @@ +package hub + +import ( + "strings" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/assert" +) + +func TestNewUpdate(t *testing.T) { + u := newUpdate([]string{"foo"}, true, Event{Retry: 3}) + + assert.Equal(t, []string{"foo"}, u.Topics) + assert.True(t, u.Private) + assert.Equal(t, uint64(3), u.Retry) + + assert.True(t, strings.HasPrefix(u.ID, "urn:uuid:")) + + _, err := uuid.FromString(strings.TrimPrefix(u.ID, "urn:uuid:")) + assert.Nil(t, err) +} diff --git a/public/app.js b/public/app.js index 53dbf3ab..b8fd3190 100644 --- a/public/app.js +++ b/public/app.js @@ -4,7 +4,7 @@ const origin = window.location.origin; const defaultTopic = origin + "/demo/books/1.jsonld"; const defaultJwt = - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InN1YnNjcmliZSI6WyJmb28iLCJiYXIiXSwicHVibGlzaCI6WyJmb28iXX19.afLx2f2ut3YgNVFStCx95Zm_UND1mZJ69OenXaDuZL8"; + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtZXJjdXJlIjp7InB1Ymxpc2giOlsiKiJdLCJzdWJzY3JpYmUiOlsiaHR0cHM6Ly9leGFtcGxlLmNvbS9teS1wcml2YXRlLXRvcGljIiwiaHR0cDovL2xvY2FsaG9zdDozMDAwL2RlbW8vYm9va3Mve2lkfS5qc29ubGQiXSwicGF5bG9hZCI6eyJ1c2VyIjoiaHR0cHM6Ly9leGFtcGxlLmNvbS91c2Vycy9kdW5nbGFzIiwicmVtb3RlX2FkZHIiOiIxMjcuMC4wLjEifX19.bRUavgS2H9GyCHq7eoPUL_rZm2L7fGujtyyzUhiOsnw"; const updates = document.querySelector("#updates"); const settingsForm = document.forms.settings; @@ -146,7 +146,7 @@ foo`; publishForm.onsubmit = function (e) { e.preventDefault(); const { - elements: { topics, data, targets, id, type, retry }, + elements: { topics, data, priv, id, type, retry }, } = this; const body = new URLSearchParams({ @@ -157,10 +157,7 @@ foo`; }); topics.value.split("\n").forEach((topic) => body.append("topic", topic)); - targets.value !== "" && - targets.value - .split("\n") - .forEach((target) => body.append("target", target)); + priv.checked && body.append("private", "on") const opt = { method: "POST", body }; if (settingsForm.authorization.value === "header") diff --git a/public/index.html b/public/index.html index b90a1dd6..68d7769d 100644 --- a/public/index.html +++ b/public/index.html @@ -62,11 +62,11 @@
{
"mercure": {
- "subscribe": ["list of authorized targets, * for all, omit for public only"],
- "publish": ["list of authorized targets, * for all, omit to not allow to publish"]
+ "subscribe": ["list of topic selectors, * for all, omit for public only"],
+ "publish": ["list of topic selectors, * for all, omit to not allow to publish"]
}
}
!ChangeMe!
)
@@ -119,6 +119,7 @@
One URI template or string per line (try the tester).
+ Use *
to subscribe to all topics.
Examples:
- First line: canonical IRI.
- Next lines: alternate IRIs.
One string per line.