diff --git a/go/vt/vitessdriver/rows.go b/go/vt/vitessdriver/rows.go index d2ace7bdfad..a2438bb891c 100644 --- a/go/vt/vitessdriver/rows.go +++ b/go/vt/vitessdriver/rows.go @@ -17,10 +17,14 @@ limitations under the License. package vitessdriver import ( + "database/sql" "database/sql/driver" "io" + "reflect" + "time" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" ) // rows creates a database/sql/driver compliant Row iterator @@ -58,3 +62,60 @@ func (ri *rows) Next(dest []driver.Value) error { ri.index++ return nil } + +var ( + typeInt8 = reflect.TypeOf(int8(0)) + typeUint8 = reflect.TypeOf(uint8(0)) + typeInt16 = reflect.TypeOf(int16(0)) + typeUint16 = reflect.TypeOf(uint16(0)) + typeInt32 = reflect.TypeOf(int32(0)) + typeUint32 = reflect.TypeOf(uint32(0)) + typeInt64 = reflect.TypeOf(int64(0)) + typeUint64 = reflect.TypeOf(uint64(0)) + typeFloat32 = reflect.TypeOf(float32(0)) + typeFloat64 = reflect.TypeOf(float64(0)) + typeRawBytes = reflect.TypeOf(sql.RawBytes{}) + typeTime = reflect.TypeOf(time.Time{}) + typeUnknown = reflect.TypeOf(new(interface{})) +) + +// Implements the RowsColumnTypeScanType interface +func (ri *rows) ColumnTypeScanType(index int) reflect.Type { + field := ri.qr.Fields[index] + switch field.GetType() { + case query.Type_INT8: + return typeInt8 + case query.Type_UINT8: + return typeUint8 + case query.Type_INT16, query.Type_YEAR: + return typeInt16 + case query.Type_UINT16: + return typeUint16 + case query.Type_INT24: + return typeInt32 + case query.Type_UINT24: // no 24 bit type, using 32 instead + return typeUint32 + case query.Type_INT32: + return typeInt32 + case query.Type_UINT32: + return typeUint32 + case query.Type_INT64: + return typeInt64 + case query.Type_UINT64: + return typeUint64 + case query.Type_FLOAT32: + return typeFloat32 + case query.Type_FLOAT64: + return typeFloat64 + case query.Type_TIMESTAMP, query.Type_DECIMAL, query.Type_VARCHAR, query.Type_TEXT, + query.Type_BLOB, query.Type_VARBINARY, query.Type_CHAR, query.Type_BINARY, query.Type_BIT, + query.Type_ENUM, query.Type_SET, query.Type_TUPLE, query.Type_GEOMETRY, query.Type_JSON, + query.Type_HEXNUM, query.Type_HEXVAL, query.Type_BITNUM: + + return typeRawBytes + case query.Type_DATE, query.Type_TIME, query.Type_DATETIME: + return typeTime + default: + return typeUnknown + } +} diff --git a/go/vt/vitessdriver/rows_test.go b/go/vt/vitessdriver/rows_test.go index fdfc478ad16..13584e70dd8 100644 --- a/go/vt/vitessdriver/rows_test.go +++ b/go/vt/vitessdriver/rows_test.go @@ -18,10 +18,12 @@ package vitessdriver import ( "database/sql/driver" + "fmt" "io" "reflect" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -135,3 +137,92 @@ func TestRows(t *testing.T) { _ = ri.Close() } + +// Test that the ColumnTypeScanType function returns the correct reflection type for each +// sql type. The sql type in turn comes from a table column's type. +func TestColumnTypeScanType(t *testing.T) { + var r = sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "field1", + Type: sqltypes.Int8, + }, + { + Name: "field2", + Type: sqltypes.Uint8, + }, + { + Name: "field3", + Type: sqltypes.Int16, + }, + { + Name: "field4", + Type: sqltypes.Uint16, + }, + { + Name: "field5", + Type: sqltypes.Int24, + }, + { + Name: "field6", + Type: sqltypes.Uint24, + }, + { + Name: "field7", + Type: sqltypes.Int32, + }, + { + Name: "field8", + Type: sqltypes.Uint32, + }, + { + Name: "field9", + Type: sqltypes.Int64, + }, + { + Name: "field10", + Type: sqltypes.Uint64, + }, + { + Name: "field11", + Type: sqltypes.Float32, + }, + { + Name: "field12", + Type: sqltypes.Float64, + }, + { + Name: "field13", + Type: sqltypes.VarBinary, + }, + { + Name: "field14", + Type: sqltypes.Datetime, + }, + }, + } + + ri := newRows(&r, &converter{}).(driver.RowsColumnTypeScanType) + defer ri.Close() + + wantTypes := []reflect.Type{ + typeInt8, + typeUint8, + typeInt16, + typeUint16, + typeInt32, + typeUint32, + typeInt32, + typeUint32, + typeInt64, + typeUint64, + typeFloat32, + typeFloat64, + typeRawBytes, + typeTime, + } + + for i := 0; i < len(wantTypes); i++ { + assert.Equal(t, ri.ColumnTypeScanType(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeScanType(i), wantTypes[i])) + } +}