Skip to content

Commit

Permalink
Merge branch 'main' into ivr_api
Browse files Browse the repository at this point in the history
  • Loading branch information
seeflood authored Aug 22, 2022
2 parents 8c42752 + c116079 commit ad7fc05
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 3 deletions.
112 changes: 110 additions & 2 deletions components/rpc/invoker/mosn/channel/xchannel.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"

"mosn.io/pkg/buffer"
"mosn.io/pkg/log"

"mosn.io/api"

common "mosn.io/layotto/components/pkg/common"
Expand Down Expand Up @@ -93,8 +97,104 @@ type xChannel struct {
pool *connPool
}

// Do is handle RPCRequest to RPCResponse
func (m *xChannel) Do(req *rpc.RPCRequest) (*rpc.RPCResponse, error) {
// InvokeWithTargetAddress send request to specific provider address
func (m *xChannel) InvokeWithTargetAddress(req *rpc.RPCRequest) (*rpc.RPCResponse, error) {
// 1. context.WithTimeout
timeout := time.Duration(req.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(req.Ctx, timeout)
defer cancel()

// 2. get connection with specific address
conn, err := net.Dial("tcp", req.Header[rpc.TargetAddress][0])
if err != nil {
return nil, err
}
wc := &wrapConn{Conn: conn}

// 3. encode request
frame := m.proto.ToFrame(req)
buf, encErr := m.proto.Encode(req.Ctx, frame)
if encErr != nil {
return nil, common.Error(common.InternalCode, encErr.Error())
}

callChan := make(chan call, 1)
// 4. set timeout
deadline, _ := ctx.Deadline()
if err := conn.SetWriteDeadline(deadline); err != nil {
return nil, common.Error(common.UnavailebleCode, err.Error())
}

// 5. read package
go func() {
var err error
defer func() {
if err != nil {
callChan <- call{err: err}
}
wc.Close()
}()

wc.buf = buffer.NewIoBuffer(defaultBufSize)
for {
// read data from connection
n, readErr := wc.buf.ReadOnce(conn)
if readErr != nil {
err = readErr
if readErr == io.EOF {
log.DefaultLogger.Debugf("[runtime][rpc]direct conn read-loop err: %s", readErr.Error())
} else {
log.DefaultLogger.Errorf("[runtime][rpc]direct conn read-loop err: %s", readErr.Error())
}
}

if n > 0 {
iframe, decodeErr := m.proto.Decode(context.TODO(), wc.buf)
if decodeErr != nil {
err = decodeErr
log.DefaultLogger.Errorf("[runtime][rpc]direct conn decode frame err: %s", err)
break
}
frame, ok := iframe.(api.XRespFrame)
if frame == nil {
continue
}
if !ok {
err = errors.New("[runtime][rpc]xchannel type not XRespFrame")
log.DefaultLogger.Errorf("[runtime][rpc]direct conn decode frame err: %s", err)
break
}
callChan <- call{resp: frame}
return

}
if err != nil {
break
}
if wc.buf != nil && wc.buf.Len() == 0 && wc.buf.Cap() > maxBufSize {
wc.buf.Free()
wc.buf.Alloc(defaultBufSize)
}
}
}()

// 6. write packet
if _, err := conn.Write(buf.Bytes()); err != nil {
return nil, common.Error(common.UnavailebleCode, err.Error())
}

select {
case res := <-callChan:
if res.err != nil {
return nil, common.Error(common.UnavailebleCode, res.err.Error())
}
return m.proto.FromFrame(res.resp)
case <-ctx.Done():
return nil, common.Error(common.TimeoutCode, ErrTimeout.Error())
}
}

func (m *xChannel) Invoke(req *rpc.RPCRequest) (*rpc.RPCResponse, error) {
// 1. context.WithTimeout
timeout := time.Duration(req.Timeout) * time.Millisecond
ctx, cancel := context.WithTimeout(req.Ctx, timeout)
Expand Down Expand Up @@ -151,6 +251,14 @@ func (m *xChannel) Do(req *rpc.RPCRequest) (*rpc.RPCResponse, error) {
}
}

// Do is handle RPCRequest to RPCResponse
func (m *xChannel) Do(req *rpc.RPCRequest) (*rpc.RPCResponse, error) {
if _, ok := req.Header[rpc.TargetAddress]; ok && len(req.Header[rpc.TargetAddress]) > 0 {
return m.InvokeWithTargetAddress(req)
}
return m.Invoke(req)
}

// removeCall is delete xstate.calls by id
func (m *xChannel) removeCall(xstate *xstate, id uint32) {
xstate.mu.Lock()
Expand Down
9 changes: 8 additions & 1 deletion components/rpc/invoker/mosn/mosninvoker.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"errors"
"fmt"
"strconv"

// bridge to mosn
_ "mosn.io/mosn/pkg/filter/network/proxy"
Expand Down Expand Up @@ -93,7 +94,13 @@ func (m *mosnInvoker) Invoke(ctx context.Context, req *rpc.RPCRequest) (resp *rp

// 1. validate request
if req.Timeout == 0 {
req.Timeout = 3000
req.Timeout = rpc.DefaultRequestTimeoutMs
if ts, ok := req.Header[rpc.RequestTimeoutMs]; ok && len(ts) > 0 {
t, err := strconv.Atoi(ts[0])
if err == nil && t != 0 {
req.Timeout = int32(t)
}
}
}
req.Ctx = ctx
log.DefaultLogger.Debugf("[runtime][rpc]request %+v", req)
Expand Down
17 changes: 17 additions & 0 deletions components/rpc/invoker/mosn/mosninvoker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,27 @@ func Test_mosnInvoker_Invoke(t *testing.T) {
Timeout: 100,
Method: "Hello",
Data: []byte("hello"),
Header: map[string][]string{},
}
rsp, err := invoker.Invoke(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, "hello world!", string(rsp.Data))

req.Header[rpc.RequestTimeoutMs] = []string{"0"}
req.Timeout = 0
rsp, err = invoker.Invoke(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, "hello world!", string(rsp.Data))

assert.Equal(t, int32(3000), req.Timeout)

req.Header[rpc.RequestTimeoutMs] = []string{"100000"}
req.Timeout = 0
rsp, err = invoker.Invoke(context.Background(), req)
assert.Nil(t, err)
assert.Equal(t, "hello world!", string(rsp.Data))

assert.Equal(t, int32(100000), req.Timeout)
})

t.Run("panic", func(t *testing.T) {
Expand Down
9 changes: 9 additions & 0 deletions components/rpc/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ import (
"strings"
)

const (
TargetAddress = "rpc_target_address"
RequestTimeoutMs = "rpc_request_timeout"
)

const (
DefaultRequestTimeoutMs = 3000
)

// RPCHeader is storage header info
type RPCHeader map[string][]string

Expand Down

0 comments on commit ad7fc05

Please sign in to comment.