Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for additional AMQP URI query parameters #251

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .ci/versions.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"erlang": "26.1.1",
"rabbitmq": "3.12.6"
"erlang": "26.2.2",
"rabbitmq": "3.13.0"
}
45 changes: 38 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ func DefaultDial(connectionTimeout time.Duration) func(network, addr string) (ne
// scheme. It is equivalent to calling DialTLS(amqp, nil).
func Dial(url string) (*Connection, error) {
return DialConfig(url, Config{
Heartbeat: defaultHeartbeat,
Locale: defaultLocale,
Locale: defaultLocale,
})
}

Expand All @@ -169,7 +168,6 @@ func Dial(url string) (*Connection, error) {
// DialTLS uses the provided tls.Config when encountering an amqps:// scheme.
func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
return DialConfig(url, Config{
Heartbeat: defaultHeartbeat,
TLSClientConfig: amqps,
Locale: defaultLocale,
})
Expand All @@ -186,7 +184,6 @@ func DialTLS(url string, amqps *tls.Config) (*Connection, error) {
// amqps:// scheme.
func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
return DialConfig(url, Config{
Heartbeat: defaultHeartbeat,
TLSClientConfig: amqps,
SASL: []Authentication{&ExternalAuth{}},
})
Expand All @@ -195,7 +192,9 @@ func DialTLS_ExternalAuth(url string, amqps *tls.Config) (*Connection, error) {
// DialConfig accepts a string in the AMQP URI format and a configuration for
// the transport and connection setup, returning a new Connection. Defaults to
// a server heartbeat interval of 10 seconds and sets the initial read deadline
// to 30 seconds.
// to 30 seconds. The heartbeat interval specified in the AMQP URI takes precedence
// over the value specified in the config. To disable heartbeats, you must use
// the AMQP URI and set heartbeat=0 there.
func DialConfig(url string, config Config) (*Connection, error) {
var err error
var conn net.Conn
Expand All @@ -206,18 +205,50 @@ func DialConfig(url string, config Config) (*Connection, error) {
}

if config.SASL == nil {
config.SASL = []Authentication{uri.PlainAuth()}
if uri.AuthMechanism != nil {
for _, identifier := range uri.AuthMechanism {
switch strings.ToUpper(identifier) {
case "PLAIN":
config.SASL = append(config.SASL, uri.PlainAuth())
case "AMQPLAIN":
config.SASL = append(config.SASL, uri.AMQPlainAuth())
case "EXTERNAL":
config.SASL = append(config.SASL, &ExternalAuth{})
default:
return nil, fmt.Errorf("unsupported auth_mechanism: %v", identifier)
}
}
} else {
config.SASL = []Authentication{uri.PlainAuth()}
}
}

if config.Vhost == "" {
config.Vhost = uri.Vhost
}

if uri.Heartbeat.hasValue {
config.Heartbeat = uri.Heartbeat.value
} else {
if config.Heartbeat == 0 {
config.Heartbeat = defaultHeartbeat
}
}

if config.ChannelMax == 0 {
config.ChannelMax = uri.ChannelMax
}

connectionTimeout := defaultConnectionTimeout
if uri.ConnectionTimeout != 0 {
connectionTimeout = time.Duration(uri.ConnectionTimeout) * time.Millisecond
}

addr := net.JoinHostPort(uri.Host, strconv.FormatInt(int64(uri.Port), 10))

dialer := config.Dial
if dialer == nil {
dialer = DefaultDial(defaultConnectionTimeout)
dialer = DefaultDial(connectionTimeout)
}

conn, err = dialer("tcp", addr)
Expand Down
13 changes: 13 additions & 0 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,16 @@ type bodyFrame struct {
}

func (f *bodyFrame) channel() uint16 { return f.ChannelId }

type heartbeatDuration struct {
value time.Duration
hasValue bool
}

func newHeartbeatDurationFromSeconds(s int) heartbeatDuration {
v := time.Duration(s) * time.Second
return heartbeatDuration{
value: v,
hasValue: true,
}
}
54 changes: 44 additions & 10 deletions uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package amqp091

import (
"errors"
"fmt"
"net"
"net/url"
"strconv"
Expand All @@ -32,16 +33,20 @@ var defaultURI = URI{

// URI represents a parsed AMQP URI string.
type URI struct {
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
Scheme string
Host string
Port int
Username string
Password string
Vhost string
CertFile string // client TLS auth - path to certificate (PEM)
CACertFile string // client TLS auth - path to CA certificate (PEM)
KeyFile string // client TLS auth - path to private key (PEM)
ServerName string // client TLS auth - server name
AuthMechanism []string
Heartbeat heartbeatDuration
ConnectionTimeout int
ChannelMax uint16
}

// ParseURI attempts to parse the given AMQP URI according to the spec.
Expand All @@ -62,6 +67,10 @@ type URI struct {
// keyfile: <path/to/client_key.pem>
// cacertfile: <path/to/ca.pem>
// server_name_indication: <server name>
// auth_mechanism: <one or more: plain, amqplain, external>
// heartbeat: <seconds (integer)>
// connection_timeout: <milliseconds (integer)>
// channel_max: <max number of channels (integer)>
//
// If cacertfile is not provided, system CA certificates will be used.
// Mutual TLS (client auth) will be enabled only in case keyfile AND certfile provided.
Expand Down Expand Up @@ -134,6 +143,31 @@ func ParseURI(uri string) (URI, error) {
builder.KeyFile = params.Get("keyfile")
builder.CACertFile = params.Get("cacertfile")
builder.ServerName = params.Get("server_name_indication")
builder.AuthMechanism = params["auth_mechanism"]

if params.Has("heartbeat") {
value, err := strconv.Atoi(params.Get("heartbeat"))
if err != nil {
return builder, fmt.Errorf("heartbeat is not an integer: %v", err)
}
builder.Heartbeat = newHeartbeatDurationFromSeconds(value)
}

if params.Has("connection_timeout") {
value, err := strconv.Atoi(params.Get("connection_timeout"))
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ConnectionTimeout = value
}

if params.Has("channel_max") {
value, err := strconv.ParseUint(params.Get("channel_max"), 10, 16)
if err != nil {
return builder, fmt.Errorf("connection_timeout is not an integer: %v", err)
}
builder.ChannelMax = uint16(value)
}

return builder, nil
}
Expand Down
25 changes: 25 additions & 0 deletions uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
package amqp091

import (
"reflect"
"testing"
"time"
)

// Test matrix defined on http://www.rabbitmq.com/uri-spec.html
Expand Down Expand Up @@ -388,3 +390,26 @@ func TestURITLSConfig(t *testing.T) {
t.Fatal("Server name not set")
}
}

func TestURIParameters(t *testing.T) {
url := "amqps://foo.bar/?auth_mechanism=plain&auth_mechanism=amqpplain&heartbeat=2&connection_timeout=5000&channel_max=8"
uri, err := ParseURI(url)
if err != nil {
t.Fatal("Could not parse")
}
if !reflect.DeepEqual(uri.AuthMechanism, []string{"plain", "amqpplain"}) {
t.Fatal("AuthMechanism not set")
}
if !uri.Heartbeat.hasValue {
t.Fatal("Heartbeat not set")
}
if uri.Heartbeat.value != time.Duration(2)*time.Second {
t.Fatal("Heartbeat not set")
}
if uri.ConnectionTimeout != 5000 {
t.Fatal("ConnectionTimeout not set")
}
if uri.ChannelMax != 8 {
t.Fatal("ChannelMax name not set")
}
}
Loading