Skip to content

Commit

Permalink
support sticky connection
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingSinger committed Dec 9, 2019
1 parent 78e7ed0 commit c5ad873
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 9 deletions.
35 changes: 31 additions & 4 deletions cluster/cluster_impl/base_cluster_invoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type baseClusterInvoker struct {
directory cluster.Directory
availablecheck bool
destroyed *atomic.Bool
stickyInvoker protocol.Invoker
}

func newBaseClusterInvoker(directory cluster.Directory) baseClusterInvoker {
Expand Down Expand Up @@ -83,15 +84,42 @@ func (invoker *baseClusterInvoker) checkWhetherDestroyed() error {
}

func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker {
//todo:sticky connect

var selectedInvoker protocol.Invoker
url := invokers[0].GetUrl()
sticky := url.GetParam(constant.STICKY_KEY, "false")
//Get the service method sticky config if have
if v := url.GetMethodParam(invocation.MethodName(), constant.STICKY_KEY, sticky); len(v) != 0 {
sticky = v
}

if invoker.stickyInvoker != nil && !isInvoked(invoker.stickyInvoker, invokers) {
invoker.stickyInvoker = nil
}

if sticky == "true" && invoker.stickyInvoker != nil && (invoked == nil || !isInvoked(invoker.stickyInvoker, invoked)) {
return invoker.stickyInvoker
}

selectedInvoker = invoker.doSelectInvoker(lb, invocation, invokers, invoked)

if sticky == "true" {
invoker.stickyInvoker = selectedInvoker
}
return selectedInvoker

}

func (invoker *baseClusterInvoker) doSelectInvoker(lb cluster.LoadBalance, invocation protocol.Invocation, invokers []protocol.Invoker, invoked []protocol.Invoker) protocol.Invoker {
if len(invokers) == 1 {
return invokers[0]
}

selectedInvoker := lb.Select(invokers, invocation)

//judge to if the selectedInvoker is invoked

if !selectedInvoker.IsAvailable() || !invoker.availablecheck || isInvoked(selectedInvoker, invoked) {
if (!selectedInvoker.IsAvailable() && invoker.availablecheck) || isInvoked(selectedInvoker, invoked) {
// do reselect
var reslectInvokers []protocol.Invoker

Expand All @@ -106,13 +134,12 @@ func (invoker *baseClusterInvoker) doSelect(lb cluster.LoadBalance, invocation p
}

if len(reslectInvokers) > 0 {
return lb.Select(reslectInvokers, invocation)
selectedInvoker = lb.Select(reslectInvokers, invocation)
} else {
return nil
}
}
return selectedInvoker

}

func isInvoked(selectedInvoker protocol.Invoker, invoked []protocol.Invoker) bool {
Expand Down
63 changes: 63 additions & 0 deletions cluster/cluster_impl/base_cluster_invoker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cluster_impl

import (
"context"
"fmt"
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

import (
"github.com/apache/dubbo-go/cluster/loadbalance"
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
)

func Test_StickyNormal(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
url.SetParam("sticky", "true")
invokers = append(invokers, NewMockInvoker(url, 1))
}
base := &baseClusterInvoker{}
invoked := []protocol.Invoker{}
result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
assert.Equal(t, result, result1)
}
func Test_StickyNormalWhenError(t *testing.T) {
invokers := []protocol.Invoker{}
for i := 0; i < 10; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/com.ikurento.user.UserProvider", i))
url.SetParam("sticky", "true")
invokers = append(invokers, NewMockInvoker(url, 1))
}
base := &baseClusterInvoker{}
invoked := []protocol.Invoker{}
result := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
invoked = append(invoked, result)
result1 := base.doSelect(loadbalance.NewRandomLoadBalance(), invocation.NewRPCInvocation("getUser", nil, nil), invokers, invoked)
assert.NotEqual(t, result, result1)
}
2 changes: 1 addition & 1 deletion cluster/cluster_impl/failback_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func Test_FailbackSuceess(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := registerFailback(t, invoker).(*failbackClusterInvoker)

invoker.EXPECT().GetUrl().Return(failbackUrl).Times(1)
invoker.EXPECT().GetUrl().Return(failbackUrl).AnyTimes()

mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}
invoker.EXPECT().Invoke(gomock.Any()).Return(mockResult)
Expand Down
4 changes: 2 additions & 2 deletions cluster/cluster_impl/failfast_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func Test_FailfastInvokeSuccess(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := registerFailfast(t, invoker)

invoker.EXPECT().GetUrl().Return(failfastUrl)
invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes()

mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}

Expand All @@ -84,7 +84,7 @@ func Test_FailfastInvokeFail(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := registerFailfast(t, invoker)

invoker.EXPECT().GetUrl().Return(failfastUrl)
invoker.EXPECT().GetUrl().Return(failfastUrl).AnyTimes()

mockResult := &protocol.RPCResult{Err: perrors.New("error")}

Expand Down
4 changes: 2 additions & 2 deletions cluster/cluster_impl/failsafe_cluster_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func Test_FailSafeInvokeSuccess(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := register_failsafe(t, invoker)

invoker.EXPECT().GetUrl().Return(failsafeUrl)
invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes()

mockResult := &protocol.RPCResult{Rest: rest{tried: 0, success: true}}

Expand All @@ -83,7 +83,7 @@ func Test_FailSafeInvokeFail(t *testing.T) {
invoker := mock.NewMockInvoker(ctrl)
clusterInvoker := register_failsafe(t, invoker)

invoker.EXPECT().GetUrl().Return(failsafeUrl)
invoker.EXPECT().GetUrl().Return(failsafeUrl).AnyTimes()

mockResult := &protocol.RPCResult{Err: perrors.New("error")}

Expand Down
1 change: 1 addition & 0 deletions common/constant/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ const (
WEIGHT_KEY = "weight"
WARMUP_KEY = "warmup"
RETRIES_KEY = "retries"
STICKY_KEY = "sticky"
BEAN_NAME = "bean.name"
FAIL_BACK_TASKS_KEY = "failbacktasks"
FORKS_KEY = "forks"
Expand Down
1 change: 1 addition & 0 deletions config/method_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type MethodConfig struct {
TpsLimitStrategy string `yaml:"tps.limit.strategy" json:"tps.limit.strategy,omitempty" property:"tps.limit.strategy"`
ExecuteLimit string `yaml:"execute.limit" json:"execute.limit,omitempty" property:"execute.limit"`
ExecuteLimitRejectedHandler string `yaml:"execute.limit.rejected.handler" json:"execute.limit.rejected.handler,omitempty" property:"execute.limit.rejected.handler"`
Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"`
}

func (c *MethodConfig) Prefix() string {
Expand Down
3 changes: 3 additions & 0 deletions config/reference_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type ReferenceConfig struct {
invoker protocol.Invoker
urls []*common.URL
Generic bool `yaml:"generic" json:"generic,omitempty" property:"generic"`
Sticky bool `yaml:"sticky" json:"sticky,omitempty" property:"sticky"`
}

func (c *ReferenceConfig) Prefix() string {
Expand Down Expand Up @@ -170,6 +171,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values {
urlMap.Set(constant.ROLE_KEY, strconv.Itoa(common.CONSUMER))
//getty invoke async or sync
urlMap.Set(constant.ASYNC_KEY, strconv.FormatBool(refconfig.async))
urlMap.Set(constant.STICKY_KEY, strconv.FormatBool(refconfig.Sticky))

//application info
urlMap.Set(constant.APPLICATION_KEY, consumerConfig.ApplicationConfig.Name)
Expand All @@ -190,6 +192,7 @@ func (refconfig *ReferenceConfig) getUrlMap() url.Values {
for _, v := range refconfig.Methods {
urlMap.Set("methods."+v.Name+"."+constant.LOADBALANCE_KEY, v.Loadbalance)
urlMap.Set("methods."+v.Name+"."+constant.RETRIES_KEY, v.Retries)
urlMap.Set("methods."+v.Name+"."+constant.STICKY_KEY, strconv.FormatBool(v.Sticky))
}

return urlMap
Expand Down
20 changes: 20 additions & 0 deletions config/reference_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func doInitConsumer() {
"serviceid": "soa.mock",
"forks": "5",
},
Sticky: false,
Registry: "shanghai_reg1,shanghai_reg2,hangzhou_reg1,hangzhou_reg2",
InterfaceName: "com.MockService",
Protocol: "mock",
Expand All @@ -103,6 +104,7 @@ func doInitConsumer() {
Name: "GetUser1",
Retries: "2",
Loadbalance: "random",
Sticky: true,
},
},
},
Expand Down Expand Up @@ -254,6 +256,24 @@ func Test_Forking(t *testing.T) {
consumerConfig = nil
}

func Test_Sticky(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
m := consumerConfig.References["MockService"]
m.Url = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"

reference := consumerConfig.References["MockService"]
reference.Refer()
referenceSticky := reference.invoker.GetUrl().GetParam(constant.STICKY_KEY, "false")
assert.Equal(t, "false", referenceSticky)

method0StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[0].Name, constant.STICKY_KEY, "false")
assert.Equal(t, "false", method0StickKey)
method1StickKey := reference.invoker.GetUrl().GetMethodParam(reference.Methods[1].Name, constant.STICKY_KEY, "false")
assert.Equal(t, "true", method1StickKey)
}

func GetProtocol() protocol.Protocol {
if regProtocol != nil {
return regProtocol
Expand Down

0 comments on commit c5ad873

Please sign in to comment.