diff --git a/internal/io/mqtt/connection_pool.go b/internal/io/mqtt/connection_pool.go index cc944a9c8a..e63e12587c 100644 --- a/internal/io/mqtt/connection_pool.go +++ b/internal/io/mqtt/connection_pool.go @@ -15,7 +15,6 @@ package mqtt import ( - "fmt" "sync" "github.com/lf-edge/ekuiper/contract/v2/api" @@ -26,45 +25,35 @@ var ( lock sync.RWMutex ) -func GetConnection(ctx api.StreamContext, props map[string]any) (*Connection, error) { - var clientId string - if sid, ok := props["connectionSelector"]; ok { - if s, ok := sid.(string); ok { - clientId = s - } else { - return nil, fmt.Errorf("connectionSelector value: %v is not string", sid) - } - } - if clientId == "" { +func GetConnection(ctx api.StreamContext, selId string, props map[string]any) (*Connection, error) { + if selId == "" { return CreateClient(ctx, "", props) } lock.Lock() defer lock.Unlock() - if conn, ok := connectionPool[clientId]; ok { + if conn, ok := connectionPool[selId]; ok { conn.attach() return conn, nil } else { - cli, err := CreateClient(ctx, clientId, props) + cli, err := CreateClient(ctx, selId, props) if err != nil { return nil, err } - connectionPool[clientId] = cli + connectionPool[selId] = cli return cli, nil } } -func DetachConnection(clientId string, subscribedTopic string) { +func DetachConnection(conn *Connection, selId string, subscribedTopic string) { + var closed bool + if subscribedTopic != "" { + closed = conn.detachSub(subscribedTopic) + } else { + closed = conn.detachPub() + } lock.Lock() defer lock.Unlock() - if conn, ok := connectionPool[clientId]; ok { - var closed bool - if subscribedTopic != "" { - closed = conn.detachSub(subscribedTopic) - } else { - closed = conn.detachPub() - } - if closed { - delete(connectionPool, clientId) - } + if _, ok := connectionPool[selId]; closed && ok { + delete(connectionPool, selId) } } diff --git a/internal/io/mqtt/connection_test.go b/internal/io/mqtt/connection_test.go index 0b6e22aa93..42704f3080 100644 --- a/internal/io/mqtt/connection_test.go +++ b/internal/io/mqtt/connection_test.go @@ -33,17 +33,17 @@ func TestConnectionLC(t *testing.T) { "server": "abc", } ctx := mockContext.NewMockContext("test", "op") - connShared1, err := GetConnection(ctx, propsShared) + connShared1, err := GetConnection(ctx, "mqtt.localConnection", propsShared) assert.NoError(t, err) assert.Equal(t, 1, len(connectionPool)) - _, err = GetConnection(ctx, propsNormal) + _, err = GetConnection(ctx, "", propsNormal) assert.NoError(t, err) assert.Equal(t, 1, len(connectionPool)) - connShared2, err := GetConnection(ctx, propsShared) + connShared2, err := GetConnection(ctx, "mqtt.localConnection", propsShared) assert.NoError(t, err) assert.Equal(t, 1, len(connectionPool)) assert.Equal(t, connShared1, connShared2) - _, err = GetConnection(ctx, propsInvalid) + _, err = GetConnection(ctx, "", propsInvalid) assert.Error(t, err) assert.Equal(t, err.Error(), "found error when connecting for abc: network Error : dial tcp: address abc: missing port in address") assert.Equal(t, 1, len(connectionPool)) @@ -56,9 +56,9 @@ func TestConnectionLC(t *testing.T) { // Test subscribe in the connector test. - DetachConnection("mqtt.localConnection", "") + DetachConnection(connShared1, "mqtt.localConnection", "") assert.Equal(t, 1, len(connectionPool)) - DetachConnection("mqtt.localConnection", "") + DetachConnection(connShared2, "mqtt.localConnection", "") assert.Equal(t, 0, len(connectionPool)) } diff --git a/internal/io/mqtt/sink.go b/internal/io/mqtt/sink.go index 634f094679..55aba5965a 100644 --- a/internal/io/mqtt/sink.go +++ b/internal/io/mqtt/sink.go @@ -25,10 +25,10 @@ import ( // AdConf is the advanced configuration for the mqtt sink type AdConf struct { - Tpc string `json:"topic"` - Qos byte `json:"qos"` - Retained bool `json:"retained"` - + Tpc string `json:"topic"` + Qos byte `json:"qos"` + Retained bool `json:"retained"` + SelId string `json:"connectionSelector"` ResendTopic string `json:"resendDestination"` } @@ -67,7 +67,7 @@ func (ms *Sink) Provision(_ api.StreamContext, ps map[string]any) error { func (ms *Sink) Connect(ctx api.StreamContext) error { ctx.GetLogger().Infof("Connecting to mqtt server") - cli, err := GetConnection(ctx, ms.config) + cli, err := GetConnection(ctx, ms.adconf.SelId, ms.config) ms.cli = cli return err } @@ -100,8 +100,7 @@ func (ms *Sink) Collect(ctx api.StreamContext, item api.RawTuple) error { func (ms *Sink) Close(ctx api.StreamContext) error { ctx.GetLogger().Info("Closing mqtt sink connector") if ms.cli != nil { - DetachConnection(ms.cli.GetClientId(), "") - ms.cli = nil + DetachConnection(ms.cli, ms.adconf.SelId, "") } return nil } diff --git a/internal/io/mqtt/sink_test.go b/internal/io/mqtt/sink_test.go index 83afccd077..da9a579c26 100644 --- a/internal/io/mqtt/sink_test.go +++ b/internal/io/mqtt/sink_test.go @@ -87,17 +87,19 @@ func TestSinkConfigure(t *testing.T) { { name: "Valid configuration with QoS 0 and no compression", input: map[string]interface{}{ - "topic": "testTopic3", - "qos": 0, - "retained": false, - "compression": "", - "server": "123", + "topic": "testTopic3", + "qos": 0, + "retained": false, + "compression": "", + "server": "123", + "connectionSelector": "mqtt.local", }, expectedAdConf: &AdConf{ Tpc: "testTopic3", Qos: 0, Retained: false, ResendTopic: "testTopic3", + SelId: "mqtt.local", }, }, { diff --git a/internal/io/mqtt/source.go b/internal/io/mqtt/source.go index 7c364f3b61..8c132bef8a 100644 --- a/internal/io/mqtt/source.go +++ b/internal/io/mqtt/source.go @@ -38,6 +38,7 @@ type SourceConnector struct { type Conf struct { Topic string `json:"datasource"` Qos int `json:"qos"` + SelId string `json:"connectionSelector"` } func (ms *SourceConnector) Provision(ctx api.StreamContext, props map[string]any) error { @@ -70,7 +71,7 @@ func (ms *SourceConnector) Ping(props map[string]interface{}) error { func (ms *SourceConnector) Connect(ctx api.StreamContext) error { ctx.GetLogger().Infof("Connecting to mqtt server") - cli, err := GetConnection(ctx, ms.props) + cli, err := GetConnection(ctx, ms.cfg.SelId, ms.props) ms.cli = cli return err } @@ -104,8 +105,7 @@ func (ms *SourceConnector) onMessage(ctx api.StreamContext, msg pahoMqtt.Message func (ms *SourceConnector) Close(ctx api.StreamContext) error { ctx.GetLogger().Infof("Closing mqtt source connector to topic %s.", ms.tpc) if ms.cli != nil { - DetachConnection(ms.cli.GetClientId(), ms.tpc) - ms.cli = nil + DetachConnection(ms.cli, ms.cfg.SelId, ms.tpc) } return nil }