Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dialect_mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
s := `select a.name as name, b.name as ctype,a.max_length,a.precision,a.scale,a.is_nullable as nullable,
"default_is_null" = (CASE WHEN c.text is null THEN 1 ELSE 0 END),
replace(replace(isnull(c.text,''),'(',''),')','') as vdefault,
ISNULL(i.is_primary_key, 0)
ISNULL(i.is_primary_key, 0), a.is_identity as is_identity
from sys.columns a
left join sys.types b on a.user_type_id=b.user_type_id
left join sys.syscomments c on a.default_object_id=c.id
Expand All @@ -362,8 +362,8 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
for rows.Next() {
var name, ctype, vdefault string
var maxLen, precision, scale int
var nullable, isPK, defaultIsNull bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK)
var nullable, isPK, defaultIsNull, isIncrement bool
err = rows.Scan(&name, &ctype, &maxLen, &precision, &scale, &nullable, &defaultIsNull, &vdefault, &isPK, &isIncrement)
if err != nil {
return nil, nil, err
}
Expand All @@ -377,6 +377,7 @@ func (db *mssql) GetColumns(tableName string) ([]string, map[string]*core.Column
col.Default = vdefault
}
col.IsPrimaryKey = isPK
col.IsAutoIncrement = isIncrement
ct := strings.ToUpper(ctype)
if ct == "DECIMAL" {
col.Length = precision
Expand Down
67 changes: 39 additions & 28 deletions dialect_sqlite3.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,40 @@ func splitColStr(colStr string) []string {
return results
}

func parseString(colStr string) (*core.Column, error) {
fields := splitColStr(colStr)
col := new(core.Column)
col.Indexes = make(map[string]int)
col.Nullable = true
col.DefaultIsEmpty = true

for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
continue
} else if idx == 1 {
col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
continue
}
switch field {
case "PRIMARY":
col.IsPrimaryKey = true
case "AUTOINCREMENT":
col.IsAutoIncrement = true
case "NULL":
if fields[idx-1] == "NOT" {
col.Nullable = false
} else {
col.Nullable = true
}
case "DEFAULT":
col.Default = fields[idx+1]
col.DefaultIsEmpty = false
}
}
return col, nil
}

func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
args := []interface{}{tableName}
s := "SELECT sql FROM sqlite_master WHERE type='table' and name = ?"
Expand Down Expand Up @@ -327,6 +361,7 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
colCreates := reg.FindAllString(name[nStart+1:nEnd], -1)
cols := make(map[string]*core.Column)
colSeq := make([]string, 0)

for _, colStr := range colCreates {
reg = regexp.MustCompile(`,\s`)
colStr = reg.ReplaceAllString(colStr, ",")
Expand All @@ -343,35 +378,11 @@ func (db *sqlite3) GetColumns(tableName string) ([]string, map[string]*core.Colu
continue
}

fields := splitColStr(colStr)
col := new(core.Column)
col.Indexes = make(map[string]int)
col.Nullable = true
col.DefaultIsEmpty = true

for idx, field := range fields {
if idx == 0 {
col.Name = strings.Trim(strings.Trim(field, "`[] "), `"`)
continue
} else if idx == 1 {
col.SQLType = core.SQLType{Name: field, DefaultLength: 0, DefaultLength2: 0}
}
switch field {
case "PRIMARY":
col.IsPrimaryKey = true
case "AUTOINCREMENT":
col.IsAutoIncrement = true
case "NULL":
if fields[idx-1] == "NOT" {
col.Nullable = false
} else {
col.Nullable = true
}
case "DEFAULT":
col.Default = fields[idx+1]
col.DefaultIsEmpty = false
}
col, err := parseString(colStr)
if err != nil {
return colSeq, cols, err
}

cols[col.Name] = col
colSeq = append(colSeq, col.Name)
}
Expand Down
49 changes: 49 additions & 0 deletions tag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,3 +549,52 @@ func TestSplitTag(t *testing.T) {
}
}
}

func TestTagAutoIncr(t *testing.T) {
assert.NoError(t, prepareEngine())

type TagAutoIncr struct {
Id int64
Name string
}

assertSync(t, new(TagAutoIncr))

tables, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
assert.EqualValues(t, tableMapper.Obj2Table("TagAutoIncr"), tables[0].Name)
col := tables[0].GetColumn(colMapper.Obj2Table("Id"))
assert.NotNil(t, col)
assert.True(t, col.IsPrimaryKey)
assert.True(t, col.IsAutoIncrement)

col2 := tables[0].GetColumn(colMapper.Obj2Table("Name"))
assert.NotNil(t, col2)
assert.False(t, col2.IsPrimaryKey)
assert.False(t, col2.IsAutoIncrement)
}

func TestTagPrimarykey(t *testing.T) {
assert.NoError(t, prepareEngine())
type TagPrimaryKey struct {
Id int64 `xorm:"pk"`
Name string `xorm:"VARCHAR(20) pk"`
}

assertSync(t, new(TagPrimaryKey))

tables, err := testEngine.DBMetas()
assert.NoError(t, err)
assert.EqualValues(t, 1, len(tables))
assert.EqualValues(t, tableMapper.Obj2Table("TagPrimaryKey"), tables[0].Name)
col := tables[0].GetColumn(colMapper.Obj2Table("Id"))
assert.NotNil(t, col)
assert.True(t, col.IsPrimaryKey)
assert.False(t, col.IsAutoIncrement)

col2 := tables[0].GetColumn(colMapper.Obj2Table("Name"))
assert.NotNil(t, col2)
assert.True(t, col2.IsPrimaryKey)
assert.False(t, col2.IsAutoIncrement)
}
8 changes: 7 additions & 1 deletion xorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ import (

_ "github.com/denisenkom/go-mssqldb"
_ "github.com/go-sql-driver/mysql"
"xorm.io/core"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
"xorm.io/core"
)

var (
Expand All @@ -35,6 +35,9 @@ var (
splitter = flag.String("splitter", ";", "the splitter on connstr for cluster")
schema = flag.String("schema", "", "specify the schema")
ignoreSelectUpdate = flag.Bool("ignore_select_update", false, "ignore select update if implementation difference, only for tidb")

tableMapper core.IMapper
colMapper core.IMapper
)

func createEngine(dbType, connStr string) error {
Expand Down Expand Up @@ -122,6 +125,9 @@ func createEngine(dbType, connStr string) error {
}
}

tableMapper = testEngine.GetTableMapper()
colMapper = testEngine.GetColumnMapper()

tables, err := testEngine.DBMetas()
if err != nil {
return err
Expand Down