diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 5112bf32b2..7a1fe273dd 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -4428,6 +4428,14 @@ Select * from ( Query: "SELECT id FROM typestable WHERE da < date_add('2019-12-30', INTERVAL 1 DAY)", Expected: nil, }, + { + Query: "SELECT id FROM typestable WHERE da < adddate('2020-01-01', INTERVAL 1 DAY)", + Expected: []sql.Row{{int64(1)}}, + }, + { + Query: "SELECT id FROM typestable WHERE da < adddate('2020-01-01', 1)", + Expected: []sql.Row{{int64(1)}}, + }, { Query: "SELECT id FROM typestable WHERE ti > date_sub('2020-01-01', INTERVAL 1 DAY)", Expected: []sql.Row{{int64(1)}}, @@ -4456,6 +4464,18 @@ Select * from ( Query: "SELECT id FROM typestable WHERE da >= subdate('2020-01-01', 1)", Expected: []sql.Row{{int64(1)}}, }, + { + Query: "SELECT adddate(da, i32) from typestable;", + Expected: []sql.Row{{time.Date(2020, time.January, 4, 0, 0, 0, 0, time.UTC)}}, + }, + { + Query: "SELECT adddate(da, concat(u32)) from typestable;", + Expected: []sql.Row{{time.Date(2020, time.January, 8, 0, 0, 0, 0, time.UTC)}}, + }, + { + Query: "SELECT adddate(da, f32/10) from typestable;", + Expected: []sql.Row{{time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC)}}, + }, { Query: "SELECT subdate(da, i32) from typestable;", Expected: []sql.Row{{time.Date(2019, time.December, 27, 0, 0, 0, 0, time.UTC)}}, diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go index 8498a21dad..749799fc4b 100644 --- a/sql/expression/function/date.go +++ b/sql/expression/function/date.go @@ -25,6 +25,26 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) +// NewAddDate returns a new function expression, or an error if one couldn't be created. The ADDDATE +// function is a synonym for DATE_ADD, with the one exception that if the second argument is NOT an +// explicitly declared interval, then the value is used and the interval period is assumed to be DAY. +// In either case, this function will actually return a *DateAdd struct. +func NewAddDate(args ...sql.Expression) (sql.Expression, error) { + if len(args) != 2 { + return nil, sql.ErrInvalidArgumentNumber.New("ADDDATE", 2, len(args)) + } + + // If the interval is explicitly specified, then we simply pass it all to DateSub + i, ok := args[1].(*expression.Interval) + if ok { + return &DateAdd{args[0], i}, nil + } + + // Otherwise, the interval period is assumed to be DAY + i = expression.NewInterval(args[1], "DAY") + return &DateAdd{args[0], i}, nil +} + // DateAdd adds an interval to a date. type DateAdd struct { Date sql.Expression @@ -91,12 +111,12 @@ func (d *DateAdd) WithChildren(children ...sql.Expression) (sql.Expression, erro // Eval implements the sql.Expression interface. func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - val, err := d.Date.Eval(ctx, row) + date, err := d.Date.Eval(ctx, row) if err != nil { return nil, err } - if val == nil { + if date == nil { return nil, nil } @@ -109,7 +129,7 @@ func (d *DateAdd) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - date, _, err := types.DatetimeMaxPrecision.Convert(val) + date, _, err = types.DatetimeMaxPrecision.Convert(date) if err != nil { ctx.Warn(1292, err.Error()) return nil, nil @@ -224,12 +244,6 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - date, _, err = types.DatetimeMaxPrecision.Convert(date) - if err != nil { - ctx.Warn(1292, err.Error()) - return nil, nil - } - delta, err := d.Interval.EvalDelta(ctx, row) if err != nil { return nil, err @@ -239,6 +253,12 @@ func (d *DateSub) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } + date, _, err = types.DatetimeMaxPrecision.Convert(date) + if err != nil { + ctx.Warn(1292, err.Error()) + return nil, nil + } + // return appropriate type res := types.ValidateTime(delta.Sub(date.(time.Time))) resType := d.Type() diff --git a/sql/expression/function/date_test.go b/sql/expression/function/date_test.go index 41c00e5f1c..8d8c52b44c 100644 --- a/sql/expression/function/date_test.go +++ b/sql/expression/function/date_test.go @@ -25,6 +25,64 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) +// MySQL's ADDDATE function is just syntactic sugar on top of DATE_ADD. The first param is the date, and the +// second is the value to add. If the second param is an interval type, it gets passed to DATE_ADD as-is. If +// it is not an explicit interval, the interval period is assumed to be "DAY". +func TestAddDate(t *testing.T) { + require := require.New(t) + ctx := sql.NewEmptyContext() + + // Not enough params + _, err := NewAddDate() + require.Error(err) + + // Not enough params + _, err = NewAddDate(expression.NewLiteral("2018-05-02", types.LongText)) + require.Error(err) + + // If the second argument is NOT an interval, then it's assumed to be a day interval + f, err := NewAddDate( + expression.NewLiteral("2018-05-02", types.LongText), + expression.NewLiteral(int64(1), types.Int64)) + require.NoError(err) + expected := time.Date(2018, time.May, 3, 0, 0, 0, 0, time.UTC) + result, err := f.Eval(ctx, sql.Row{}) + require.NoError(err) + require.Equal(expected, result) + + // If the second argument is an interval, then ADDDATE works exactly like DATE_ADD + f, err = NewAddDate( + expression.NewGetField(0, types.Text, "foo", false), + expression.NewInterval(expression.NewLiteral(int64(1), types.Int64), "DAY")) + require.NoError(err) + result, err = f.Eval(ctx, sql.Row{"2018-05-02"}) + require.NoError(err) + require.Equal(expected, result) + + // If the interval param is NULL, then NULL is returned + f2, err := NewAddDate( + expression.NewLiteral("2018-05-02", types.LongText), + expression.NewGetField(0, types.Int64, "foo", true)) + result, err = f2.Eval(ctx, sql.Row{nil}) + require.NoError(err) + require.Nil(result) + + // If the date param is NULL, then NULL is returned + result, err = f.Eval(ctx, sql.Row{nil}) + require.NoError(err) + require.Nil(result) + + // If a time is passed (and no date) then NULL is returned + result, err = f.Eval(ctx, sql.Row{"12:00:56"}) + require.NoError(err) + require.Nil(result) + + // If an invalid date is passed, then NULL is returned + result, err = f.Eval(ctx, sql.Row{"asdasdasd"}) + require.NoError(err) + require.Nil(result) +} + func TestDateAdd(t *testing.T) { require := require.New(t) ctx := sql.NewEmptyContext() diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 0f98a7be6b..4d5655885c 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -35,6 +35,7 @@ var BuiltIns = []sql.Function{ // elt, find_in_set, insert, load_file, locate sql.Function1{Name: "abs", Fn: NewAbsVal}, sql.Function1{Name: "acos", Fn: NewAcos}, + sql.FunctionN{Name: "adddate", Fn: NewAddDate}, sql.Function1{Name: "any_value", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewAnyValue(e) }}, sql.Function1{Name: "ascii", Fn: NewAscii}, sql.Function1{Name: "asin", Fn: NewAsin},