Skip to content

Commit

Permalink
fix: fix cast security issue (#3243)
Browse files Browse the repository at this point in the history
Signed-off-by: Song Gao <[email protected]>
  • Loading branch information
Yisaer authored Sep 29, 2024
1 parent 62299ee commit 58c15a6
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 9 deletions.
4 changes: 4 additions & 0 deletions internal/binder/function/funcs_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/lf-edge/ekuiper/contract/v2/api"

"github.com/lf-edge/ekuiper/v2/internal/conf"
"github.com/lf-edge/ekuiper/v2/pkg/ast"
"github.com/lf-edge/ekuiper/v2/pkg/cast"
)
Expand Down Expand Up @@ -630,6 +631,9 @@ func conv(str string, fromBase, toBase int64) (res string, isNull bool, err erro
val = -val
}

if val > math.MaxInt64 {
conf.Log.Warnf("value %d is out of int64 range", val)
}
if int64(val) < 0 {
negative = true
} else {
Expand Down
4 changes: 4 additions & 0 deletions internal/binder/function/funcs_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,10 @@ func TestConvFunc(t *testing.T) {
}
for _, c := range cases {
got, _ := fConv.exec(fctx, []interface{}{c.args[0], c.args[1], c.args[2]})
if c.getErr {
require.Error(t, got.(error))
continue
}
if got != c.expected {
t.Errorf("%s:Expected %s, but got %s", c.args[0], c.expected, got)
}
Expand Down
14 changes: 14 additions & 0 deletions internal/converter/protobuf/fieldConverterSingleton.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@ package protobuf

import (
"fmt"
"math"

// TODO: replace with `google.golang.org/protobuf/proto` pkg.
"github.com/golang/protobuf/proto" //nolint:staticcheck
dpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/jhump/protoreflect/desc" //nolint:staticcheck
"github.com/jhump/protoreflect/dynamic" //nolint:staticcheck

"github.com/lf-edge/ekuiper/v2/internal/conf"
"github.com/lf-edge/ekuiper/v2/pkg/cast"
)

Expand Down Expand Up @@ -111,6 +113,9 @@ func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}
if err != nil {
return 0, nil
} else {
if r > math.MaxInt32 {
conf.Log.Warnf("value %d is out of int32 range", r)
}
return int32(r), nil
}
}, "int", cast.CONVERT_SAMEKIND)
Expand All @@ -122,6 +127,9 @@ func (fc *FieldConverter) EncodeField(field *desc.FieldDescriptor, v interface{}
if err != nil {
return 0, nil
} else {
if r > math.MaxUint32 {
conf.Log.Warnf("value %d is out of uint32 range", v)
}
return uint32(r), nil
}
}, "uint", cast.CONVERT_SAMEKIND)
Expand Down Expand Up @@ -174,6 +182,9 @@ func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v inter
case dpb.FieldDescriptorProto_TYPE_INT32, dpb.FieldDescriptorProto_TYPE_SFIXED32, dpb.FieldDescriptorProto_TYPE_SINT32, dpb.FieldDescriptorProto_TYPE_ENUM:
r, err := cast.ToInt(v, cast.CONVERT_SAMEKIND)
if err == nil {
if r > math.MaxInt32 {
conf.Log.Warnf("value %d is out of int32 range", v)
}
return int32(r), nil
} else {
return nil, fmt.Errorf("invalid type for int type field '%s': %v", fn, err)
Expand All @@ -188,6 +199,9 @@ func (fc *FieldConverter) encodeSingleField(field *desc.FieldDescriptor, v inter
case dpb.FieldDescriptorProto_TYPE_FIXED32, dpb.FieldDescriptorProto_TYPE_UINT32:
r, err := cast.ToUint64(v, cast.CONVERT_SAMEKIND)
if err == nil {
if r > math.MaxUint32 {
conf.Log.Warnf("value %d is out of uint32 range", v)
}
return uint32(r), nil
} else {
return nil, fmt.Errorf("invalid type for uint type field '%s': %v", fn, err)
Expand Down
2 changes: 1 addition & 1 deletion internal/server/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func getTraceIDByRuleID(w http.ResponseWriter, r *http.Request) {
if err != nil {
limit = 0
}
root, err := tracer.GetTraceIDListByRuleID(id, int(limit))
root, err := tracer.GetTraceIDListByRuleID(id, limit)
if err != nil {
handleError(w, err, "", logger)
return
Expand Down
16 changes: 16 additions & 0 deletions pkg/cast/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cast
import (
"encoding/base64"
"fmt"
"math"
"reflect"
"strconv"
"sync"
Expand Down Expand Up @@ -187,6 +188,9 @@ func ToInt8(input interface{}, sn Strictness) (int8, error) {
if sn == CONVERT_ALL {
v, err := strconv.ParseInt(s, 0, 0)
if err == nil {
if v > math.MaxInt8 {
return 0, fmt.Errorf("value %d is out of int8 range", v)
}
return int8(v), nil
}
}
Expand Down Expand Up @@ -239,6 +243,9 @@ func ToInt16(input interface{}, sn Strictness) (int16, error) {
if sn == CONVERT_ALL {
v, err := strconv.ParseInt(s, 0, 0)
if err == nil {
if v > math.MaxInt16 {
return 0, fmt.Errorf("value %d is out of int32 range", v)
}
return int16(v), nil
}
}
Expand Down Expand Up @@ -291,6 +298,9 @@ func ToInt32(input interface{}, sn Strictness) (int32, error) {
if sn == CONVERT_ALL {
v, err := strconv.ParseInt(s, 0, 0)
if err == nil {
if v > math.MaxInt32 {
return 0, fmt.Errorf("value %d is out of int32 range", v)
}
return int32(v), nil
}
}
Expand Down Expand Up @@ -577,6 +587,9 @@ func ToUint8(i interface{}, sn Strictness) (uint8, error) {
if sn == CONVERT_ALL {
v, err := strconv.ParseUint(s, 0, 64)
if err == nil {
if v > math.MaxUint8 {
return 0, fmt.Errorf("value %d is out of uint16 range", v)
}
return uint8(v), nil
}
}
Expand Down Expand Up @@ -650,6 +663,9 @@ func ToUint16(i interface{}, sn Strictness) (uint16, error) {
if sn == CONVERT_ALL {
v, err := strconv.ParseUint(s, 0, 64)
if err == nil {
if v > math.MaxUint16 {
return 0, fmt.Errorf("value %d is out of uint16 range", v)
}
return uint16(v), nil
}
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/tracer/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ func (l *SpanExporter) GetTraceById(traceID string) (*LocalSpan, error) {
return l.spanStorage.GetTraceById(traceID)
}

func (l *SpanExporter) GetTraceByRuleID(ruleID string, limit int) ([]string, error) {
func (l *SpanExporter) GetTraceByRuleID(ruleID string, limit int64) ([]string, error) {
return l.spanStorage.GetTraceByRuleID(ruleID, limit)
}

type LocalSpanStorage interface {
SaveSpan(span sdktrace.ReadOnlySpan) error
GetTraceById(traceID string) (*LocalSpan, error)
GetTraceByRuleID(ruleID string, limit int) ([]string, error)
GetTraceByRuleID(ruleID string, limit int64) ([]string, error)
}

type LocalSpanMemoryStorage struct {
Expand Down Expand Up @@ -169,15 +169,15 @@ func (l *LocalSpanMemoryStorage) GetTraceById(traceID string) (*LocalSpan, error
return rootSpan, nil
}

func (l *LocalSpanMemoryStorage) GetTraceByRuleID(ruleID string, limit int) ([]string, error) {
func (l *LocalSpanMemoryStorage) GetTraceByRuleID(ruleID string, limit int64) ([]string, error) {
l.RLock()
defer l.RUnlock()
traceMap := l.ruleTraceMap[ruleID]
r := make([]string, 0)
if limit < 1 {
limit = len(traceMap)
limit = int64(len(traceMap))
}
count := 0
count := int64(0)
for traceID := range traceMap {
r = append(r, traceID)
count++
Expand Down Expand Up @@ -291,7 +291,7 @@ func (s *sqlSpanStorage) GetTraceById(traceID string) (*LocalSpan, error) {
return s.loadTraceByTraceID(traceID)
}

func (s *sqlSpanStorage) GetTraceByRuleID(ruleID string, limit int) ([]string, error) {
func (s *sqlSpanStorage) GetTraceByRuleID(ruleID string, limit int64) ([]string, error) {
return s.loadTraceByRuleID(ruleID)
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/tracer/tracer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (g *GlobalTracerManager) GetTraceById(traceID string) (root *LocalSpan, err
return g.SpanExporter.GetTraceById(traceID)
}

func (g *GlobalTracerManager) GetTraceByRuleID(ruleID string, limit int) ([]string, error) {
func (g *GlobalTracerManager) GetTraceByRuleID(ruleID string, limit int64) ([]string, error) {
g.RLock()
defer g.RUnlock()
return g.SpanExporter.GetTraceByRuleID(ruleID, limit)
Expand Down Expand Up @@ -145,7 +145,7 @@ func loadTracerConfig() (*TracerConfig, error) {
return tracerConfig, nil
}

func GetTraceIDListByRuleID(ruleID string, limit int) ([]string, error) {
func GetTraceIDListByRuleID(ruleID string, limit int64) ([]string, error) {
globalTracerManager.InitIfNot()
return globalTracerManager.GetTraceByRuleID(ruleID, limit)
}

0 comments on commit 58c15a6

Please sign in to comment.