diff --git a/go/mysql/datetime/time_zone.go b/go/mysql/datetime/time_zone.go new file mode 100644 index 00000000000..046e06ed240 --- /dev/null +++ b/go/mysql/datetime/time_zone.go @@ -0,0 +1,79 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datetime + +import ( + "fmt" + "strconv" + "time" + + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func unknownTimeZone(tz string) error { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.UnknownTimeZone, "Unknown or incorrect time zone: '%s'", tz) +} + +func ParseTimeZone(tz string) (*time.Location, error) { + // Needs to be checked first since time.LoadLocation("") returns UTC. + if tz == "" { + return nil, unknownTimeZone(tz) + } + loc, err := time.LoadLocation(tz) + if err == nil { + return loc, nil + } + + // MySQL also handles timezone formats in the form of the + // offset from UTC, so we'll try that if the above fails. + // This format is always something in the form of +HH:MM or -HH:MM. + if len(tz) != 6 { + return nil, unknownTimeZone(tz) + } + if tz[0] != '+' && tz[0] != '-' { + return nil, unknownTimeZone(tz) + } + if tz[3] != ':' { + return nil, unknownTimeZone(tz) + } + neg := tz[0] == '-' + hours, err := strconv.ParseUint(tz[1:3], 10, 4) + if err != nil { + return nil, unknownTimeZone(tz) + } + minutes, err := strconv.ParseUint(tz[4:], 10, 6) + if err != nil { + return nil, unknownTimeZone(tz) + } + if minutes > 59 { + return nil, unknownTimeZone(tz) + } + + // MySQL only supports timezones in the range of -13:59 to +14:00. + if neg && hours > 13 { + return nil, unknownTimeZone(tz) + } + if !neg && (hours > 14 || hours == 14 && minutes > 0) { + return nil, unknownTimeZone(tz) + } + offset := int(hours)*60*60 + int(minutes)*60 + if neg { + offset = -offset + } + return time.FixedZone(fmt.Sprintf("UTC%s", tz), offset), nil +} diff --git a/go/mysql/datetime/time_zone_test.go b/go/mysql/datetime/time_zone_test.go new file mode 100644 index 00000000000..94745d0c71e --- /dev/null +++ b/go/mysql/datetime/time_zone_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2019 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datetime + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseTimeZone(t *testing.T) { + testCases := []struct { + tz string + want string + }{ + { + tz: "Europe/Amsterdam", + want: "Europe/Amsterdam", + }, + { + tz: "", + want: "Unknown or incorrect time zone: ''", + }, + { + tz: "+02:00", + want: "UTC+02:00", + }, + { + tz: "+14:00", + want: "UTC+14:00", + }, + { + tz: "+14:01", + want: "Unknown or incorrect time zone: '+14:01'", + }, + { + tz: "-13:59", + want: "UTC-13:59", + }, + { + tz: "-14:00", + want: "Unknown or incorrect time zone: '-14:00'", + }, + { + tz: "-15:00", + want: "Unknown or incorrect time zone: '-15:00'", + }, + { + tz: "foo", + want: "Unknown or incorrect time zone: 'foo'", + }, + } + + for _, tc := range testCases { + + zone, err := ParseTimeZone(tc.tz) + if err != nil { + assert.Equal(t, tc.want, err.Error()) + } else { + assert.Equal(t, tc.want, zone.String()) + } + } +} diff --git a/go/mysql/sql_error.go b/go/mysql/sql_error.go index 369b486c048..83a52e7e3f7 100644 --- a/go/mysql/sql_error.go +++ b/go/mysql/sql_error.go @@ -210,6 +210,7 @@ var stateToMysqlCode = map[vterrors.State]mysqlCode{ vterrors.NoSuchSession: {num: ERUnknownComError, state: SSNetError}, vterrors.OperandColumns: {num: EROperandColumns, state: SSWrongNumberOfColumns}, vterrors.WrongValueCountOnRow: {num: ERWrongValueCountOnRow, state: SSWrongValueCountOnRow}, + vterrors.UnknownTimeZone: {num: ERUnknownTimeZone, state: SSUnknownSQLState}, } func getStateToMySQLState(state vterrors.State) mysqlCode { diff --git a/go/vt/vterrors/state.go b/go/vt/vterrors/state.go index ae5a4970d2b..4bc963cc32b 100644 --- a/go/vt/vterrors/state.go +++ b/go/vt/vterrors/state.go @@ -83,6 +83,9 @@ const ( // server not available ServerNotAvailable + // unknown timezone + UnknownTimeZone + // No state should be added below NumOfStates NumOfStates ) diff --git a/go/vt/vtgate/safe_session.go b/go/vt/vtgate/safe_session.go index f39fd1176bc..1fa60e0a12d 100644 --- a/go/vt/vtgate/safe_session.go +++ b/go/vt/vtgate/safe_session.go @@ -25,6 +25,8 @@ import ( "google.golang.org/protobuf/proto" + "vitess.io/vitess/go/mysql/datetime" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/sysvars" @@ -553,9 +555,9 @@ func (session *SafeSession) TimeZone() *time.Location { session.mu.Unlock() if !ok { - return time.Local + return nil } - loc, _ := time.LoadLocation(tz) + loc, _ := datetime.ParseTimeZone(tz) return loc } diff --git a/go/vt/vtgate/safe_session_test.go b/go/vt/vtgate/safe_session_test.go index 4bcd095362c..21bb2d6697a 100644 --- a/go/vt/vtgate/safe_session_test.go +++ b/go/vt/vtgate/safe_session_test.go @@ -19,7 +19,9 @@ package vtgate import ( "reflect" "testing" + "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" querypb "vitess.io/vitess/go/vt/proto/query" @@ -64,3 +66,35 @@ func TestPrequeries(t *testing.T) { t.Errorf("got %v but wanted %v", preQueries, want) } } + +func TestTimeZone(t *testing.T) { + testCases := []struct { + tz string + want string + }{ + { + tz: "Europe/Amsterdam", + want: "Europe/Amsterdam", + }, + { + tz: "+02:00", + want: "UTC+02:00", + }, + { + tz: "foo", + want: (*time.Location)(nil).String(), + }, + } + + for _, tc := range testCases { + t.Run(tc.tz, func(t *testing.T) { + session := NewSafeSession(&vtgatepb.Session{ + SystemVariables: map[string]string{ + "time_zone": tc.tz, + }, + }) + + assert.Equal(t, tc.want, session.TimeZone().String()) + }) + } +}