Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions go/vt/vitessdriver/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@ limitations under the License.
package vitessdriver

import (
"database/sql"
"database/sql/driver"
"io"
"reflect"
"time"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/proto/query"
)

// rows creates a database/sql/driver compliant Row iterator
Expand Down Expand Up @@ -58,3 +62,60 @@ func (ri *rows) Next(dest []driver.Value) error {
ri.index++
return nil
}

var (
typeInt8 = reflect.TypeOf(int8(0))
typeUint8 = reflect.TypeOf(uint8(0))
typeInt16 = reflect.TypeOf(int16(0))
typeUint16 = reflect.TypeOf(uint16(0))
typeInt32 = reflect.TypeOf(int32(0))
typeUint32 = reflect.TypeOf(uint32(0))
typeInt64 = reflect.TypeOf(int64(0))
typeUint64 = reflect.TypeOf(uint64(0))
typeFloat32 = reflect.TypeOf(float32(0))
typeFloat64 = reflect.TypeOf(float64(0))
typeRawBytes = reflect.TypeOf(sql.RawBytes{})
typeTime = reflect.TypeOf(time.Time{})
typeUnknown = reflect.TypeOf(new(interface{}))
)

// Implements the RowsColumnTypeScanType interface
func (ri *rows) ColumnTypeScanType(index int) reflect.Type {
field := ri.qr.Fields[index]
switch field.GetType() {
case query.Type_INT8:
return typeInt8
case query.Type_UINT8:
return typeUint8
case query.Type_INT16, query.Type_YEAR:
return typeInt16
case query.Type_UINT16:
return typeUint16
case query.Type_INT24:
return typeInt32
case query.Type_UINT24: // no 24 bit type, using 32 instead
return typeUint32
case query.Type_INT32:
return typeInt32
case query.Type_UINT32:
return typeUint32
case query.Type_INT64:
return typeInt64
case query.Type_UINT64:
return typeUint64
case query.Type_FLOAT32:
return typeFloat32
case query.Type_FLOAT64:
return typeFloat64
case query.Type_TIMESTAMP, query.Type_DECIMAL, query.Type_VARCHAR, query.Type_TEXT,
query.Type_BLOB, query.Type_VARBINARY, query.Type_CHAR, query.Type_BINARY, query.Type_BIT,
query.Type_ENUM, query.Type_SET, query.Type_TUPLE, query.Type_GEOMETRY, query.Type_JSON,
query.Type_HEXNUM, query.Type_HEXVAL, query.Type_BITNUM:

return typeRawBytes
case query.Type_DATE, query.Type_TIME, query.Type_DATETIME:
return typeTime
default:
return typeUnknown
}
}
91 changes: 91 additions & 0 deletions go/vt/vitessdriver/rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ package vitessdriver

import (
"database/sql/driver"
"fmt"
"io"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -135,3 +137,92 @@ func TestRows(t *testing.T) {

_ = ri.Close()
}

// Test that the ColumnTypeScanType function returns the correct reflection type for each
// sql type. The sql type in turn comes from a table column's type.
func TestColumnTypeScanType(t *testing.T) {
var r = sqltypes.Result{
Fields: []*querypb.Field{
{
Name: "field1",
Type: sqltypes.Int8,
},
{
Name: "field2",
Type: sqltypes.Uint8,
},
{
Name: "field3",
Type: sqltypes.Int16,
},
{
Name: "field4",
Type: sqltypes.Uint16,
},
{
Name: "field5",
Type: sqltypes.Int24,
},
{
Name: "field6",
Type: sqltypes.Uint24,
},
{
Name: "field7",
Type: sqltypes.Int32,
},
{
Name: "field8",
Type: sqltypes.Uint32,
},
{
Name: "field9",
Type: sqltypes.Int64,
},
{
Name: "field10",
Type: sqltypes.Uint64,
},
{
Name: "field11",
Type: sqltypes.Float32,
},
{
Name: "field12",
Type: sqltypes.Float64,
},
{
Name: "field13",
Type: sqltypes.VarBinary,
},
{
Name: "field14",
Type: sqltypes.Datetime,
},
},
}

ri := newRows(&r, &converter{}).(driver.RowsColumnTypeScanType)
defer ri.Close()

wantTypes := []reflect.Type{
typeInt8,
typeUint8,
typeInt16,
typeUint16,
typeInt32,
typeUint32,
typeInt32,
typeUint32,
typeInt64,
typeUint64,
typeFloat32,
typeFloat64,
typeRawBytes,
typeTime,
}

for i := 0; i < len(wantTypes); i++ {
assert.Equal(t, ri.ColumnTypeScanType(i), wantTypes[i], fmt.Sprintf("unexpected type %v, wanted %v", ri.ColumnTypeScanType(i), wantTypes[i]))
}
}