diff --git a/sql/types/system_bool.go b/sql/types/system_bool.go index 2da977ad05..34cd251c11 100644 --- a/sql/types/system_bool.go +++ b/sql/types/system_bool.go @@ -36,6 +36,7 @@ type SystemBoolType struct { var _ sql.SystemVariableType = SystemBoolType{} var _ sql.CollationCoercible = SystemBoolType{} +var _ sql.NumberType = SystemBoolType{} // NewSystemBoolType returns a new systemBoolType. func NewSystemBoolType(varName string) sql.SystemVariableType { @@ -184,6 +185,21 @@ func (t SystemBoolType) Zero() interface{} { return int8(0) } +// IsNumericType implements the sql.NumberType interface. +func (t SystemBoolType) IsNumericType() bool { + return true +} + +// IsFloat implements the sql.NumberType interface. +func (t SystemBoolType) IsFloat() bool { + return false +} + +// DisplayWidth implements the sql.NumberType interface. +func (t SystemBoolType) DisplayWidth() int { + return t.UnderlyingType().(sql.NumberType).DisplayWidth() +} + // EncodeValue implements SystemVariableType interface. func (t SystemBoolType) EncodeValue(val interface{}) (string, error) { expectedVal, ok := val.(int8) diff --git a/sql/types/system_double.go b/sql/types/system_double.go index c3c920e2dc..abda70f835 100644 --- a/sql/types/system_double.go +++ b/sql/types/system_double.go @@ -37,6 +37,7 @@ type systemDoubleType struct { var _ sql.SystemVariableType = systemDoubleType{} var _ sql.CollationCoercible = systemDoubleType{} +var _ sql.NumberType = systemDoubleType{} // NewSystemDoubleType returns a new systemDoubleType. func NewSystemDoubleType(varName string, lowerbound, upperbound float64) sql.SystemVariableType { @@ -169,6 +170,21 @@ func (t systemDoubleType) Zero() interface{} { return float64(0) } +// IsNumericType implements the sql.NumberType interface. +func (t systemDoubleType) IsNumericType() bool { + return true +} + +// IsFloat implements the sql.NumberType interface. +func (t systemDoubleType) IsFloat() bool { + return true +} + +// DisplayWidth implements the sql.NumberType interface. +func (t systemDoubleType) DisplayWidth() int { + return t.UnderlyingType().(sql.NumberType).DisplayWidth() +} + // CollationCoercibility implements sql.CollationCoercible interface. func (systemDoubleType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 diff --git a/sql/types/system_int.go b/sql/types/system_int.go index 6d17582684..eae2c35194 100644 --- a/sql/types/system_int.go +++ b/sql/types/system_int.go @@ -38,6 +38,7 @@ type systemIntType struct { var _ sql.SystemVariableType = systemIntType{} var _ sql.CollationCoercible = systemIntType{} +var _ sql.NumberType = systemIntType{} // NewSystemIntType returns a new systemIntType. func NewSystemIntType(varName string, lowerbound, upperbound int64, negativeOne bool) sql.SystemVariableType { @@ -179,6 +180,21 @@ func (t systemIntType) Zero() interface{} { return int64(0) } +// IsNumericType implements the sql.NumberType interface. +func (t systemIntType) IsNumericType() bool { + return true +} + +// IsFloat implements the sql.NumberType interface. +func (t systemIntType) IsFloat() bool { + return false +} + +// DisplayWidth implements the sql.NumberType interface. +func (t systemIntType) DisplayWidth() int { + return t.UnderlyingType().(sql.NumberType).DisplayWidth() +} + // CollationCoercibility implements sql.CollationCoercible interface. func (systemIntType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 diff --git a/sql/types/system_types_test.go b/sql/types/system_types_test.go new file mode 100644 index 0000000000..6a64479ef3 --- /dev/null +++ b/sql/types/system_types_test.go @@ -0,0 +1,37 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/dolthub/go-mysql-server/sql" +) + +func TestSystemTypesImplementSqlTypeInterfaces(t *testing.T) { + assert.True(t, sql.IsNumberType(SystemBoolType{})) + assert.True(t, sql.IsNumberType(systemIntType{})) + assert.True(t, sql.IsNumberType(systemUintType{})) + assert.True(t, sql.IsNumberType(systemDoubleType{})) + + assert.False(t, sql.IsNumberType(systemEnumType{})) + assert.False(t, sql.IsNumberType(systemSetType{})) + assert.False(t, sql.IsNumberType(systemStringType{})) + + assert.True(t, sql.IsStringType(systemStringType{})) + assert.False(t, sql.IsStringType(SystemBoolType{})) +} diff --git a/sql/types/system_uint.go b/sql/types/system_uint.go index 859d0116b2..70754c935c 100644 --- a/sql/types/system_uint.go +++ b/sql/types/system_uint.go @@ -37,6 +37,7 @@ type systemUintType struct { var _ sql.SystemVariableType = systemUintType{} var _ sql.CollationCoercible = systemUintType{} +var _ sql.NumberType = systemUintType{} // NewSystemUintType returns a new systemUintType. func NewSystemUintType(varName string, lowerbound, upperbound uint64) sql.SystemVariableType { @@ -168,6 +169,21 @@ func (t systemUintType) Zero() interface{} { return uint64(0) } +// IsNumericType implements the sql.NumberType interface. +func (t systemUintType) IsNumericType() bool { + return true +} + +// IsFloat implements the sql.NumberType interface. +func (t systemUintType) IsFloat() bool { + return false +} + +// DisplayWidth implements the sql.NumberType interface. +func (t systemUintType) DisplayWidth() int { + return t.UnderlyingType().(sql.NumberType).DisplayWidth() +} + // CollationCoercibility implements sql.CollationCoercible interface. func (systemUintType) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5