From eb375d5c6f6807eb46789bcfcaf10f5e860c3434 Mon Sep 17 00:00:00 2001 From: Huan Du Date: Tue, 5 Nov 2024 14:24:10 +0800 Subject: [PATCH] fix #178: avoid stackoverflow when Cond is misused --- args.go | 12 ++++++++---- args_test.go | 2 +- cond.go | 14 +++++++++++++- cond_test.go | 13 +++++++++++++ 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/args.go b/args.go index 36f8b34..8adf28c 100644 --- a/args.go +++ b/args.go @@ -16,6 +16,7 @@ type Args struct { // The default flavor used by `Args#Compile` Flavor Flavor + indexBase int argValues []interface{} namedArgs map[string]int sqlNamedArgs map[string]int @@ -47,7 +48,7 @@ func (args *Args) Add(arg interface{}) string { } func (args *Args) add(arg interface{}) int { - idx := len(args.argValues) + idx := len(args.argValues) + args.indexBase switch a := arg.(type) { case sql.NamedArg: @@ -164,7 +165,7 @@ func (args *Args) compileNamed(ctx *argsCompileContext, format string) string { format = format[i+1:] if p, ok := args.namedArgs[name]; ok { - format, _ = args.compileSuccessive(ctx, format, p) + format, _ = args.compileSuccessive(ctx, format, p-args.indexBase) } return format @@ -181,14 +182,17 @@ func (args *Args) compileDigits(ctx *argsCompileContext, format string, offset i format = format[i:] if pointer, err := strconv.Atoi(digits); err == nil { - return args.compileSuccessive(ctx, format, pointer) + return args.compileSuccessive(ctx, format, pointer-args.indexBase) } return format, offset } func (args *Args) compileSuccessive(ctx *argsCompileContext, format string, offset int) (string, int) { - if offset >= len(args.argValues) { + if offset < 0 || offset >= len(args.argValues) { + ctx.WriteString("/* INVALID ARG $") + ctx.WriteString(strconv.Itoa(offset)) + ctx.WriteString(" */") return format, offset } diff --git a/args_test.go b/args_test.go index 88a58f9..3212741 100644 --- a/args_test.go +++ b/args_test.go @@ -23,7 +23,7 @@ func TestArgs(t *testing.T) { cases := map[string][]interface{}{ "abc ? def\n[123]": {"abc $? def", 123}, "abc ? def\n[456]": {"abc $0 def", 456}, - "abc def\n[]": {"abc $1 def", 123}, + "abc /* INVALID ARG $1 */ def\n[]": {"abc $1 def", 123}, "abc def \n[]": {"abc ${unknown} def ", 123}, "abc $ def\n[]": {"abc $$ def", 123}, "abcdef$\n[]": {"abcdef$", 123}, diff --git a/cond.go b/cond.go index 2396b6d..01754a0 100644 --- a/cond.go +++ b/cond.go @@ -11,6 +11,8 @@ const ( opNOT = "NOT " ) +const minIndexBase = 256 + // Cond provides several helper methods to build conditions. type Cond struct { Args *Args @@ -19,7 +21,17 @@ type Cond struct { // NewCond returns a new Cond. func NewCond() *Cond { return &Cond{ - Args: &Args{}, + Args: &Args{ + // Based on the discussion in #174, users may call this method to create + // `Cond` for building various conditions, which is a misuse, but we + // cannot completely prevent this error. To facilitate users in + // identifying the issue when they make mistakes and to avoid + // unexpected stackoverflows, the base index for `Args` is + // deliberately set to a larger non-zero value here. This can + // significantly reduce the likelihood of issues and allows for + // timely error notification to users. + indexBase: minIndexBase, + }, } } diff --git a/cond_test.go b/cond_test.go index 2587855..ec2b4e2 100644 --- a/cond_test.go +++ b/cond_test.go @@ -230,3 +230,16 @@ func TestCondExpr(t *testing.T) { a.Equal(actual, expected) } } + +func TestCondMisuse(t *testing.T) { + a := assert.New(t) + + cond := NewCond() + sb := Select("*"). + From("t1"). + Where(cond.Equal("a", 123)) + sql, args := sb.Build() + + a.Equal(sql, "SELECT * FROM t1 WHERE /* INVALID ARG $256 */") + a.Equal(args, nil) +}