Skip to content

Commit

Permalink
feat: 重构thrift & kitex相关逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
fanhaodong.516 committed Sep 24, 2024
1 parent 9b8b2a6 commit 3b510ac
Show file tree
Hide file tree
Showing 50 changed files with 11,248 additions and 331 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/devtool.txt

This file was deleted.

3 changes: 2 additions & 1 deletion command/codec/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ func newThriftCodecCmd() (*cobra.Command, error) {

handlerMessage := func(payload []byte) error {
buffer := bufio.NewReader(bytes.NewBuffer(payload))
protocol, err := thrift_codec.GetProtocol(ctx, buffer)
protocol, metaInfo, err := thrift_codec.GetProtocol(ctx, buffer)
if err != nil {
return fmt.Errorf(`decode message find err: %v`, err)
}
data, err := thrift_codec.DecodeMessage(ctx, thrift_codec.NewTProtocol(buffer, protocol))
if err != nil {
return fmt.Errorf(`decode message find err(proto=%s): %v`, protocol, err)
}
data.MetaInfo = metaInfo
data.Protocol = protocol
_, _ = os.Stdout.WriteString(utils.ToJson(data))
return nil
Expand Down
2 changes: 1 addition & 1 deletion command/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ type HexoConfig struct {
}

type CurlConfig struct {
NewClient func(ctx context.Context, request *rpc.Request, idl *rpc.IDLInfo) (rpc.Client, error)
NewThriftClient func(ctx context.Context, request *rpc.Request) (*rpc.ThriftClient, error)
}
48 changes: 24 additions & 24 deletions command/curl/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/spf13/cobra"

"github.com/anthony-dong/golang/command"
"github.com/anthony-dong/golang/pkg/idl"
"github.com/anthony-dong/golang/pkg/logs"
"github.com/anthony-dong/golang/pkg/rpc"
"github.com/anthony-dong/golang/pkg/utils"
Expand All @@ -19,7 +18,7 @@ func NewCurlCommand(configProvider func() *command.CurlConfig) (*cobra.Command,
reqHeader := make([]string, 0)
listMethods := false
showExample := false
idlInfo := rpc.IDLInfo{}
idlConfig := rpc.IDLConfig{}
timeout := time.Second * 180
enableModifyReq := false
cmd := &cobra.Command{
Expand All @@ -29,64 +28,65 @@ func NewCurlCommand(configProvider func() *command.CurlConfig) (*cobra.Command,
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
var (
client rpc.Client
rpcRequest *rpc.Request
err error
client *rpc.ThriftClient
req *rpc.Request
err error
)
rpcRequest, err = rpc.NewRpcRequest(reqUrl, reqHeader, reqBody)
req, err = rpc.NewRpcRequest(reqUrl, reqHeader, reqBody)
if err != nil {
return err
}
rpcRequest.Timeout = utils.NewJsonDuration(timeout)
rpcRequest.EnableModifyRequest = enableModifyReq
req.Timeout = utils.NewJsonDuration(timeout)
req.EnableModifyRequest = enableModifyReq
req.IDLConfig = &idlConfig
if !showExample && !listMethods {
logs.CtxInfo(ctx, "rpc request: %s", rpcRequest.String())
logs.CtxInfo(ctx, "rpc request info: %s", utils.ToString(req.BasicInfo()))
}
config := configProvider()
if config != nil && config.NewClient != nil {
if client, err = config.NewClient(ctx, rpcRequest, &idlInfo); err != nil {
if config != nil && config.NewThriftClient != nil {
if client, err = config.NewThriftClient(ctx, req); err != nil {
return err
}
} else {
if idlInfo.Main == "" {
return fmt.Errorf(`new local idl find err: not found main idl: %q`, idlInfo.Main)
if idlConfig.Main == "" {
return fmt.Errorf(`new local idl find err: not found main idl: %q`, idlConfig.Main)
}
if client, err = rpc.NewThriftClient(rpc.NewLocalIDLProvider(map[string]string{req.ServiceName: idlConfig.Main})); err != nil {
return err
}
client = rpc.NewThriftClient(idl.NewDescriptorProvider(idl.NewMemoryIDLProvider(idlInfo.Main)))
}

if listMethods {
allMethods, err := client.ListMethods(ctx)
allMethods, err := client.ListMethods(ctx, req.ServiceName, req.IDLConfig)
if err != nil {
return fmt.Errorf(`list methods find err: %v`, err)
}
logs.CtxInfo(ctx, "methods:\n%s", utils.ToJson(allMethods, true))
return nil
}

if showExample {
jsonExample, err := client.GetExampleCode(ctx, &rpc.Method{RPCMethod: rpcRequest.RPCMethod})
jsonExample, err := client.GetExampleCode(ctx, req.ServiceName, req.IDLConfig, req.ServiceName)
if err != nil {
return fmt.Errorf(`new request example find err: %v`, err)
}
logs.CtxInfo(ctx, "new request example\n%s", jsonExample.Body)
logs.CtxInfo(ctx, "new request example\n%s", jsonExample)
return nil
}

rpcResponse, err := client.Do(ctx, rpcRequest)
resp, err := client.Do(ctx, req)
if err != nil {
return fmt.Errorf(`do rpc request find err: %v`, err)
}
flag := "success"
if rpcResponse.IsError {
if resp.IsError {
flag = "error"
}
logs.CtxInfo(ctx, "rpc response %s: %s", flag, rpcResponse.String())
logs.CtxInfo(ctx, "rpc response %s:\n%s", flag, utils.PrettyJsonBytes(resp.Body))
return nil
},
}
cmd.Flags().StringVar(&reqUrl, "url", "", "The request url")
cmd.Flags().StringVar(&idlInfo.Main, "idl", "", "The main IDL local path")
cmd.Flags().StringVar(&idlInfo.Branch, "branch", "", "The Remote IDL branch/version/commit(if supports it)")
cmd.Flags().StringVar(&idlConfig.Main, "idl", "", "The main IDL local path")
cmd.Flags().StringVar(&idlConfig.Branch, "branch", "", "The Remote IDL branch/version/commit(if supports it)")
cmd.Flags().StringSliceVarP(&reqHeader, "header", "H", []string{}, "The request header")
cmd.Flags().StringVar(&reqBody, "data", "", "The request body")
cmd.Flags().BoolVar(&listMethods, "methods", false, "List all the methods")
Expand Down
6 changes: 4 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ require (
gopkg.in/yaml.v3 v3.0.1
)

require github.com/golang/protobuf v1.5.2
require (
github.com/cloudwego/gopkg v0.1.2
github.com/golang/protobuf v1.5.2
)

require (
github.com/bytedance/sonic v1.12.2 // indirect
Expand All @@ -44,7 +47,6 @@ require (
github.com/bytedance/sonic/loader v0.2.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/dynamicgo v0.4.0 // indirect
github.com/cloudwego/gopkg v0.1.2 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/cloudwego/localsession v0.0.2 // indirect
github.com/cloudwego/runtimex v0.1.0 // indirect
Expand Down
203 changes: 203 additions & 0 deletions pkg/codec/thrift_codec/encode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package thrift_codec

import (
"encoding/json"
"fmt"
"strconv"

"github.com/anthony-dong/golang/pkg/utils"
"github.com/apache/thrift/lib/go/thrift"
"github.com/cloudwego/kitex/pkg/generic/descriptor"
"github.com/iancoleman/orderedmap"
)

func EncodeMessage(oprot thrift.TProtocol, desc *descriptor.TypeDescriptor, data interface{}) error {
return encodeThriftType(oprot, "", desc, data)
}

func EncodeReply(oprot thrift.TProtocol, function *descriptor.FunctionDescriptor, seq int32, data interface{}) error {
if err := oprot.WriteMessageBegin(function.Name, thrift.REPLY, seq); err != nil {
return err
}
if err := EncodeMessage(oprot, function.Response, data); err != nil {
return err
}
if err := oprot.WriteMessageEnd(); err != nil {
return err
}
return nil
}

func safeToOrderMap(data interface{}) (*orderedmap.OrderedMap, bool) {
switch v := data.(type) {
case map[string]interface{}:
kv := orderedmap.New()
for key, value := range v {
kv.Set(key, value)
}
return kv, true
case *orderedmap.OrderedMap:
return v, true
case orderedmap.OrderedMap:
return &v, true
default:
return nil, false
}
}

func toThriftType(p descriptor.Type) thrift.TType {
return thrift.TType(p.ToThriftTType())
}

func encodeThriftType(iprot thrift.TProtocol, fieldName string, tType *descriptor.TypeDescriptor, data interface{}) error {
if data == nil {
return nil
}
switch tType.Type {
case descriptor.STRUCT:
kv, isOK := safeToOrderMap(data)
if !isOK {
return fmt.Errorf(`invalid type: %T`, data)
}
if err := iprot.WriteStructBegin(fieldName); err != nil {
return err
}
for _, k := range kv.Keys() {
v, _ := kv.Get(k)
desc := tType.Struct.FieldsByName[k]
if desc == nil {
return fmt.Errorf(`not found field name %s`, k)
}
fieldType := toThriftType(desc.Type.Type)
// if is void resp. skip write field
if fieldType == thrift.VOID {
continue
}
if err := iprot.WriteFieldBegin(desc.Name, fieldType, int16(desc.ID)); err != nil {
return err
}
if err := encodeThriftType(iprot, desc.Name, desc.Type, v); err != nil {
return err
}
if err := iprot.WriteFieldEnd(); err != nil {
return err
}
}
if err := iprot.WriteFieldStop(); err != nil {
return err
}
if err := iprot.WriteStructEnd(); err != nil {
return err
}
case descriptor.LIST, descriptor.SET:
datas, isOK := data.([]interface{})
if !isOK {
return fmt.Errorf(`invalid type: %T`, data)
}
if err := iprot.WriteListBegin(toThriftType(tType.Elem.Type), len(datas)); err != nil {
return err
}
for _, v := range datas {
if err := encodeThriftType(iprot, fieldName, tType.Elem, v); err != nil {
return err
}
}
if err := iprot.WriteListEnd(); err != nil {
return err
}
case descriptor.MAP:
kv, isOK := data.(map[string]interface{})
if !isOK {
return fmt.Errorf(`invalid type: %T`, data)
}
if err := iprot.WriteMapBegin(toThriftType(tType.Key.Type), toThriftType(tType.Elem.Type), len(kv)); err != nil {
return err
}
for k, v := range kv {
if err := encodeThriftType(iprot, fieldName, tType.Key, k); err != nil {
return err
}
if err := encodeThriftType(iprot, fieldName, tType.Elem, v); err != nil {
return err
}
}
if err := iprot.WriteMapEnd(); err != nil {
return err
}
case descriptor.STRING:
if tType.Name == "binary" {
switch v := data.(type) {
case []byte:
if err := iprot.WriteBinary(v); err != nil {
return err
}
default:
if err := iprot.WriteBinary(utils.String2Bytes(utils.ToString(data))); err != nil {
return err
}
}
} else {
if err := iprot.WriteString(utils.ToString(data)); err != nil {
return err
}
}
case descriptor.I08:
i8, err := toInt64(data, 8)
if err != nil {
return err
}
if err := iprot.WriteByte(int8(i8)); err != nil {
return err
}
case descriptor.I16:
i16, err := toInt64(data, 16)
if err != nil {
return err
}
if err := iprot.WriteI16(int16(i16)); err != nil {
return err
}
case descriptor.I32:
i32, err := toInt64(data, 32)
if err != nil {
return err
}
if err := iprot.WriteI32(int32(i32)); err != nil {
return err
}
case descriptor.I64:
i64, err := toInt64(data, 64)
if err != nil {
return err
}
if err := iprot.WriteI64(i64); err != nil {
return err
}
case descriptor.BOOL:
value := utils.ToString(data)
parseBool, err := strconv.ParseBool(value)
if err != nil {
return err
}
if err := iprot.WriteBool(parseBool); err != nil {
return err
}
case descriptor.VOID:
return nil
default:
return fmt.Errorf(`unsupported type: %s`, tType.Type)
}
return nil
}

func toInt64(data interface{}, size int) (int64, error) {
switch v := data.(type) {
case json.Number:
return v.Int64()
case float64:
return int64(v), nil
case int64:
return v, nil
}
return strconv.ParseInt(fmt.Sprintf("%v", data), 10, size)
}
Loading

0 comments on commit 3b510ac

Please sign in to comment.