Skip to content

Commit d850527

Browse files
authored
cmd:新增dao命令 (#9)
1 parent e172a91 commit d850527

40 files changed

+2379
-251
lines changed

.CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@
77
- [internal/generate:添加select模板并将模板细化成文件](https://github.com/gotomicro/egen/pull/5)
88
- [internal/generate:添加update模板](https://github.com/gotomicro/egen/pull/6)
99
- [internal/model:新增ast读取文件](https://github.com/gotomicro/egen/pull/8)
10+
- [cmd/egen:egen dao命令](https://github.com/gotomicro/egen/pull/9)

.licenserc.json

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
"**/*.go": "// Copyright 2021 gotomicro",
33
"**/*.{yml,toml}": "# Copyright 2021 gotomicro",
44
"ignore": [
5-
"internal/generate/testdata"
5+
"internal/generate/testdata",
6+
"internal/integration/dao",
7+
"internal/integration/generate"
68
]
79
}

LICENSE

+1-1
Original file line numberDiff line numberDiff line change
@@ -198,4 +198,4 @@
198198
distributed under the License is distributed on an "AS IS" BASIS,
199199
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200200
See the License for the specific language governing permissions and
201-
limitations under the License.
201+
limitations under the License.

cmd/egen/dao/dao.go

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright 2021 gotomicro
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package daocmd
16+
17+
import (
18+
"flag"
19+
"fmt"
20+
"github.com/gotomicro/egen/internal/utils"
21+
"log"
22+
"os"
23+
"path/filepath"
24+
"strings"
25+
)
26+
27+
var (
28+
src, dst, dataModel, path string
29+
DaoFlagSet = initDaoFlagSet()
30+
tips = `
31+
-src file/dir -dst file/dir -type string -> 将src中的指定的type生成到./dst中
32+
-src file/dir -dst dir -> 将src中的所有type,生成对应的./dst/type_dao.go
33+
-src file/dir -type string -> 若src中存在该type,则生成./dao/type_dao.go
34+
-src file -> 将src中的所有type,生成对应的./dao/type_dao.go
35+
-src dir -> 扫描src下的所有go文件,若存在type,则生成type_dao.go.不会递归往下查找
36+
-dst file/dir -type string -> 若当前目录下存在该type,则生成./dst/type_dao.go
37+
-type string -> 在当前目录下的go文件,若存在type,则生成./dao/type_dao.go
38+
39+
`
40+
)
41+
42+
const (
43+
allModel = ""
44+
defaultDst = "./dao"
45+
defaultSrc = "."
46+
)
47+
48+
func ExecDao(args []string) {
49+
if len(args) < 1 {
50+
log.Println("将扫描当前目录下的所有go文件,并生成对应的type_dao.go")
51+
}
52+
DaoFlagSet.Parse(args)
53+
if err := initDao(src, dst, dataModel, path); err != nil {
54+
log.Println(err)
55+
}
56+
}
57+
58+
func initDao(src, dst, dataModel, path string) error {
59+
if dst == defaultDst || !utils.IsExist(dst) && !strings.HasSuffix(dst, ".go") {
60+
if err := os.MkdirAll(dst, 0666); err != nil {
61+
return err
62+
}
63+
} else if strings.HasSuffix(dst, ".go") && !utils.IsExist(dst) {
64+
dir, _ := filepath.Split(dst)
65+
if err := os.MkdirAll(dir, 0666); err != nil {
66+
return err
67+
}
68+
}
69+
return execWrite(src, dst, dataModel, path)
70+
}
71+
72+
func initDaoFlagSet() *flag.FlagSet {
73+
daoFlagSet := flag.NewFlagSet("dao", flag.ExitOnError)
74+
daoFlagSet.Usage = func() {
75+
fmt.Fprintf(os.Stderr, "Usage of dao:\n")
76+
daoFlagSet.PrintDefaults()
77+
fmt.Print(tips)
78+
}
79+
80+
daoFlagSet.SetOutput(os.Stdout)
81+
daoFlagSet.StringVar(&dst, "dst", defaultDst, "生成的代码写入的文件或目录")
82+
daoFlagSet.StringVar(&src, "src", defaultSrc, "读取结构体的文件或目录")
83+
daoFlagSet.StringVar(&dataModel, "type", allModel, "结构体名称")
84+
daoFlagSet.StringVar(&path, "import", "", "import时的路径")
85+
86+
return daoFlagSet
87+
}

cmd/egen/dao/exec.go

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright 2021 gotomicro
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package daocmd
16+
17+
import (
18+
"errors"
19+
"fmt"
20+
"github.com/gotomicro/egen/internal/generate"
21+
"github.com/gotomicro/egen/internal/model"
22+
"github.com/gotomicro/egen/internal/model/ast"
23+
"github.com/gotomicro/egen/internal/utils"
24+
"os"
25+
"path/filepath"
26+
"strings"
27+
)
28+
29+
func execWrite(src, dst, name, path string) error {
30+
var (
31+
dstDir = utils.IsDir(dst)
32+
srcDir = utils.IsDir(src)
33+
)
34+
35+
if !dstDir && name == allModel {
36+
return errors.New("-dst 应该是一个目录,或者使用 -type 指定了单个类型")
37+
} else if srcDir && !dstDir && name == allModel {
38+
return errors.New("-src为目录的情况下-dst也应为目录 或者使用 -type指定了单个类型")
39+
}
40+
41+
srcFiles := make([]string, 0, 10)
42+
if srcDir {
43+
files, err := os.ReadDir(src)
44+
if err != nil {
45+
return err
46+
}
47+
for _, file := range files {
48+
if strings.HasSuffix(file.Name(), ".go") {
49+
src, err = filepath.Abs(src)
50+
if err != nil {
51+
return err
52+
}
53+
srcFiles = append(srcFiles, src+"/"+file.Name())
54+
}
55+
}
56+
} else {
57+
src, err := filepath.Abs(src)
58+
if err != nil {
59+
return err
60+
}
61+
srcFiles = append(srcFiles, src)
62+
}
63+
64+
models := make([]model.Model, 0, len(srcFiles))
65+
for _, name := range srcFiles {
66+
models = append(models, ast.ParseModel(ast.LookUp(name, nil), model.WithImports(path))...)
67+
}
68+
69+
return WriteToFile(models, dst, name)
70+
}
71+
72+
func WriteToFile(models []model.Model, dst, name string) error {
73+
var mg generate.MySQLGenerator
74+
for _, v := range models {
75+
if name != "" && v.GoName != name {
76+
continue
77+
}
78+
79+
if utils.IsDir(dst) {
80+
// 可能要对多个文件进行写入 写入完成后直接close
81+
f, err := os.Create(dst + fmt.Sprintf("/%s_dao.go", ast.Convert(v.TableName)))
82+
if err != nil {
83+
f.Close() // 防止内存泄露
84+
return err
85+
}
86+
if err = mg.Generate(v, f); err != nil {
87+
f.Close() // 防止内存泄露
88+
return err
89+
}
90+
f.Close()
91+
fmt.Println(f.Name(), "已完成")
92+
} else {
93+
f, err := os.Create(dst)
94+
if err != nil {
95+
return err
96+
}
97+
98+
if err = mg.Generate(v, f); err != nil {
99+
return err
100+
}
101+
f.Close() // 只有单个文件进行写入 直接defer
102+
fmt.Println(f.Name(), "已完成")
103+
}
104+
}
105+
return nil
106+
}

cmd/egen/root.go

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Copyright 2021 gotomicro
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package cmd
16+
17+
import (
18+
"flag"
19+
"fmt"
20+
daocmd "github.com/gotomicro/egen/cmd/egen/dao"
21+
"os"
22+
)
23+
24+
var (
25+
longHelp = flag.Bool("help", false, "提供帮助")
26+
shortHelp = flag.Bool("h", false, "提供帮助")
27+
)
28+
29+
func Execute() {
30+
flag.Parse()
31+
if len(flag.Args()) > 0 {
32+
switch flag.Args()[0] {
33+
case "dao":
34+
daocmd.ExecDao(os.Args[2:])
35+
default:
36+
usage()
37+
}
38+
} else {
39+
usage()
40+
}
41+
}
42+
43+
func usage() {
44+
fmt.Println("提供以下几种命令:")
45+
daocmd.DaoFlagSet.Usage()
46+
}
47+
48+
func Help() {
49+
if *shortHelp || *longHelp {
50+
usage()
51+
}
52+
}

internal/generate/mysql.go

+3-16
Original file line numberDiff line numberDiff line change
@@ -16,34 +16,21 @@ package generate
1616

1717
import (
1818
"embed"
19-
"fmt"
2019
"github.com/gotomicro/egen/internal/model"
2120
"io"
22-
"log"
2321
"text/template"
2422
)
2523

26-
type MySQLGenerator struct {
27-
}
24+
type MySQLGenerator struct{}
2825

2926
//go:embed mysql_template
3027
var f embed.FS
3128

32-
func (*MySQLGenerator) Generate(m *model.Model, writer io.Writer) error {
29+
func (*MySQLGenerator) Generate(m model.Model, writer io.Writer) error {
3330
var err error
34-
fmt.Println(f)
35-
files := []string{"insert.gohtml", "select.gohtml", "update.gohtml", "delete.gohtml"}
3631
tMySQL, err := template.ParseFS(f, "mysql_template/*.gohtml")
3732
if err != nil {
38-
log.Println(err)
3933
return err
4034
}
41-
for _, v := range files {
42-
t := tMySQL.Lookup(v)
43-
err = t.Execute(writer, m)
44-
if err != nil {
45-
return err
46-
}
47-
}
48-
return nil
35+
return tMySQL.ExecuteTemplate(writer, "file.gohtml", &m)
4936
}

internal/generate/mysql_template/delete.gohtml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
2-
1+
{{- define "delete"}}
32
func (dao *{{.GoName}}DAO) DeleteByWhere(ctx context.Context, where string, args ...any) (int64, error) {
43
s := "DELETE FROM {{.QuotedTableName}} WHERE " + where
54
return dao.DeleteByRaw(ctx, s, args...)
@@ -12,3 +11,4 @@ func (dao *{{.GoName}}DAO) DeleteByRaw(ctx context.Context, query string, args .
1211
}
1312
return res.RowsAffected()
1413
}
14+
{{- end}}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package code
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"strings"
7+
{{- if ne .ExtralImport ""}}
8+
"{{.ExtralImport}}"
9+
{{- end}}
10+
)
11+
12+
type {{.GoName}}DAO struct {
13+
DB *sql.DB
14+
}
15+
{{template "insert" .}}
16+
{{template "select" .}}
17+
{{template "update" .}}
18+
{{template "delete" .}}
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
1-
package code
2-
3-
import (
4-
"context"
5-
"database/sql"
6-
"strings"
7-
)
8-
9-
type {{.GoName}}DAO struct {
10-
DB *sql.DB
11-
}
12-
13-
func (dao *{{.GoName}}DAO) Insert(ctx context.Context, vals ...*{{.GoName}}) (int64, error) {
14-
var args = make([]interface{}, len(vals)*({{len .Fields}}))
1+
{{- define "insert"}}
2+
func (dao *{{.GoName}}DAO) Insert(ctx context.Context, vals ...*{{.PkgName}}{{.GoName}}) (int64, error) {
3+
if len(vals) == 0 || vals == nil {
4+
return 0, nil
5+
}
6+
var args = make([]interface{}, 0, len(vals)*({{len .Fields}}))
157
var str = ""
168
for k, v := range vals {
179
if k != 0 {
@@ -26,4 +18,5 @@ func (dao *{{.GoName}}DAO) Insert(ctx context.Context, vals ...*{{.GoName}}) (in
2618
return 0, err
2719
}
2820
return res.RowsAffected()
29-
}
21+
}
22+
{{- end}}

0 commit comments

Comments
 (0)