diff --git a/cmd/gf/internal/cmd/cmd_z_unit_gen_dao_sharding_test.go b/cmd/gf/internal/cmd/cmd_z_unit_gen_dao_sharding_test.go index 5a3989667b5..df2a7a45df4 100644 --- a/cmd/gf/internal/cmd/cmd_z_unit_gen_dao_sharding_test.go +++ b/cmd/gf/internal/cmd/cmd_z_unit_gen_dao_sharding_test.go @@ -18,6 +18,92 @@ import ( "github.com/gogf/gf/cmd/gf/v2/internal/cmd/gendao" ) +// Test_Gen_Dao_Sharding_Overlapping tests the fix for issue #4603. +// When sharding patterns have overlapping prefixes (like "a_?", "a_b_?", "a_c_?"), +// longer (more specific) patterns should be matched first. +// https://github.com/gogf/gf/issues/4603 +func Test_Gen_Dao_Sharding_Overlapping(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + var ( + err error + db = testDB + tableA1 = "a_1" + tableA2 = "a_2" + tableAB1 = "a_b_1" + tableAB2 = "a_b_2" + tableAC1 = "a_c_1" + tableAC2 = "a_c_2" + sqlFilePath = gtest.DataPath(`gendao`, `sharding`, `sharding_overlapping.sql`) + ) + dropTableWithDb(db, tableA1) + dropTableWithDb(db, tableA2) + dropTableWithDb(db, tableAB1) + dropTableWithDb(db, tableAB2) + dropTableWithDb(db, tableAC1) + dropTableWithDb(db, tableAC2) + t.AssertNil(execSqlFile(db, sqlFilePath)) + defer dropTableWithDb(db, tableA1) + defer dropTableWithDb(db, tableA2) + defer dropTableWithDb(db, tableAB1) + defer dropTableWithDb(db, tableAB2) + defer dropTableWithDb(db, tableAC1) + defer dropTableWithDb(db, tableAC2) + + var ( + path = gfile.Temp(guid.S()) + group = "test" + in = gendao.CGenDaoInput{ + Path: path, + Link: link, + Group: group, + Prefix: "", + // Patterns with overlapping prefixes - order should not matter due to sorting fix + ShardingPattern: []string{ + `a_?`, // shortest, matches a_1, a_2 but also a_b_1, a_c_1 without fix + `a_b_?`, // longer, should match a_b_1, a_b_2 + `a_c_?`, // longer, should match a_c_1, a_c_2 + }, + } + ) + err = gutil.FillStructWithDefault(&in) + t.AssertNil(err) + + err = gfile.Mkdir(path) + t.AssertNil(err) + + pwd := gfile.Pwd() + err = gfile.Chdir(path) + t.AssertNil(err) + defer gfile.Chdir(pwd) + defer gfile.RemoveAll(path) + + _, err = gendao.CGenDao{}.Dao(ctx, in) + t.AssertNil(err) + + // Should generate 3 dao files: a.go, a_b.go, a_c.go (plus internal versions) + generatedFiles, err := gfile.ScanDir(path, "*.go", true) + t.AssertNil(err) + // 3 sharding groups * 4 files each (dao, internal, do, entity) = 12 files + t.Assert(len(generatedFiles), 12) + + var ( + daoAContent = gfile.GetContents(gfile.Join(path, "dao", "a.go")) + daoABContent = gfile.GetContents(gfile.Join(path, "dao", "a_b.go")) + daoACContent = gfile.GetContents(gfile.Join(path, "dao", "a_c.go")) + ) + + // Verify each sharding group has correct dao file generated + t.Assert(gstr.Contains(daoAContent, "aShardingHandler"), true) + t.Assert(gstr.Contains(daoAContent, "m.Sharding(gdb.ShardingConfig{"), true) + + t.Assert(gstr.Contains(daoABContent, "aBShardingHandler"), true) + t.Assert(gstr.Contains(daoABContent, "m.Sharding(gdb.ShardingConfig{"), true) + + t.Assert(gstr.Contains(daoACContent, "aCShardingHandler"), true) + t.Assert(gstr.Contains(daoACContent, "m.Sharding(gdb.ShardingConfig{"), true) + }) +} + func Test_Gen_Dao_Sharding(t *testing.T) { gtest.C(t, func(t *gtest.T) { var ( diff --git a/cmd/gf/internal/cmd/gendao/gendao.go b/cmd/gf/internal/cmd/gendao/gendao.go index bc6ad943ac5..777b24ea146 100644 --- a/cmd/gf/internal/cmd/gendao/gendao.go +++ b/cmd/gf/internal/cmd/gendao/gendao.go @@ -9,6 +9,7 @@ package gendao import ( "context" "fmt" + "sort" "strings" "github.com/olekukonko/tablewriter" @@ -240,13 +241,22 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) { newTableNames = make([]string, len(tableNames)) shardingNewTableSet = gset.NewStrSet() ) + // Sort sharding patterns by length descending, so that longer (more specific) patterns + // are matched first. This prevents shorter patterns like "a_?" from incorrectly matching + // tables that should match longer patterns like "a_b_?" or "a_c_?". + // https://github.com/gogf/gf/issues/4603 + sortedShardingPatterns := make([]string, len(in.ShardingPattern)) + copy(sortedShardingPatterns, in.ShardingPattern) + sort.Slice(sortedShardingPatterns, func(i, j int) bool { + return len(sortedShardingPatterns[i]) > len(sortedShardingPatterns[j]) + }) for i, tableName := range tableNames { newTableName := tableName for _, v := range removePrefixArray { newTableName = gstr.TrimLeftStr(newTableName, v, 1) } - if len(in.ShardingPattern) > 0 { - for _, pattern := range in.ShardingPattern { + if len(sortedShardingPatterns) > 0 { + for _, pattern := range sortedShardingPatterns { var ( match []string regPattern = gstr.Replace(pattern, "?", `(.+)`) @@ -262,10 +272,11 @@ func doGenDaoForArray(ctx context.Context, index int, in CGenDaoInput) { newTableName = gstr.Trim(newTableName, `_.-`) if shardingNewTableSet.Contains(newTableName) { tableNames[i] = "" - continue + break } // Add prefix to sharding table name, if not, the isSharding check would not match. shardingNewTableSet.Add(in.Prefix + newTableName) + break } } newTableName = in.Prefix + newTableName diff --git a/cmd/gf/internal/cmd/testdata/gendao/sharding/sharding_overlapping.sql b/cmd/gf/internal/cmd/testdata/gendao/sharding/sharding_overlapping.sql new file mode 100644 index 00000000000..6e6a8e9289c --- /dev/null +++ b/cmd/gf/internal/cmd/testdata/gendao/sharding/sharding_overlapping.sql @@ -0,0 +1,47 @@ +-- Test case for issue #4603: overlapping sharding patterns +-- https://github.com/gogf/gf/issues/4603 +-- +-- Patterns: "a_?", "a_b_?", "a_c_?" +-- Expected: a_1/a_2 -> "a", a_b_1/a_b_2 -> "a_b", a_c_1/a_c_2 -> "a_c" + +CREATE TABLE `a_1` +( + `id` int unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(45) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `a_2` +( + `id` int unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(45) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `a_b_1` +( + `id` int unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(45) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `a_b_2` +( + `id` int unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(45) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `a_c_1` +( + `id` int unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(45) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE `a_c_2` +( + `id` int unsigned NOT NULL AUTO_INCREMENT, + `name` varchar(45) NOT NULL, + PRIMARY KEY (`id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8;