diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index ccbf46751c0..d541bf197ed 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -2049,7 +2049,7 @@ OuterLoop: }, nil } -// createSqlTable is the private version of CreateTable. It doesn't enforce any table name checks. +// createSqlTable is the private version of CreateTable. func (db Database) createSqlTable(ctx *sql.Context, table string, schemaName string, sch sql.PrimaryKeySchema, collation sql.CollationID, comment string) error { ws, err := db.GetWorkingSet(ctx) if err != nil { @@ -2066,6 +2066,8 @@ func (db Database) createSqlTable(ctx *sql.Context, table string, schemaName str } tableName := doltdb.TableName{Name: table, Schema: schemaName} + // TODO: This check is also done in createDoltTable, which is called at the end of this function, meaning it's done + // multiple times. Consider refactoring out. if exists, err := root.HasTable(ctx, tableName); err != nil { return err } else if exists { @@ -2118,6 +2120,8 @@ func (db Database) createIndexedSqlTable(ctx *sql.Context, table string, schemaN } tableName := doltdb.TableName{Name: table, Schema: schemaName} + // TODO: This check is also done in createDoltTable, which is called at the end of this function, meaning it's done + // multiple times. Consider refactoring out. if exists, err := root.HasTable(ctx, tableName); err != nil { return err } else if exists { @@ -2160,6 +2164,8 @@ func (db Database) createIndexedSqlTable(ctx *sql.Context, table string, schemaN // createDoltTable creates a table on the database using the given dolt schema while not enforcing table baseName checks. func (db Database) createDoltTable(ctx *sql.Context, tableName string, schemaName string, root doltdb.RootValue, doltSch schema.Schema) error { + // TODO: This check is also done in createSqlTable and createIndexedSqlTable, which both call createDoltTable, + // meaning it's done multiple times. Consider refactoring. if exists, err := root.HasTable(ctx, doltdb.TableName{Name: tableName, Schema: schemaName}); err != nil { return err } else if exists { @@ -2202,13 +2208,18 @@ func (db Database) CreateTemporaryTable(ctx *sql.Context, tableName string, pkSc return ErrInvalidTableName.New(tableName) } - tmp, err := NewTempTable(ctx, db.ddb, pkSch, tableName, db.Name(), db.editOpts, collation) + ds := dsess.DSessFromSess(ctx.Session) + databaseName := db.Name() + if _, exists := ds.GetTemporaryTable(ctx, databaseName, tableName); exists { + return sql.ErrTableAlreadyExists.New(tableName) + } + + tmp, err := NewTempTable(ctx, db.ddb, pkSch, tableName, databaseName, db.editOpts, collation) if err != nil { return err } - ds := dsess.DSessFromSess(ctx.Session) - ds.AddTemporaryTable(ctx, db.Name(), tmp) + ds.AddTemporaryTable(ctx, databaseName, tmp) return nil } diff --git a/go/libraries/doltcore/sqle/dsess/session.go b/go/libraries/doltcore/sqle/dsess/session.go index 3c0690b8321..787e7eff70f 100644 --- a/go/libraries/doltcore/sqle/dsess/session.go +++ b/go/libraries/doltcore/sqle/dsess/session.go @@ -64,7 +64,7 @@ type DoltSession struct { branchController *branch_control.Controller dbCache *DatabaseCache dbStates map[string]*DatabaseSessionState - tempTables map[string][]sql.Table + tempTables map[string]map[string]sql.Table gcSafepointController *gcctx.GCSafepointController writeSessProv WriteSessFunc @@ -90,7 +90,7 @@ func DefaultSession(pro DoltDatabaseProvider, sessFunc WriteSessFunc) *DoltSessi dbStates: make(map[string]*DatabaseSessionState), dbCache: newDatabaseCache(), provider: pro, - tempTables: make(map[string][]sql.Table), + tempTables: make(map[string]map[string]sql.Table), globalsConf: config.NewMapConfig(make(map[string]string)), branchController: branch_control.CreateDefaultController(context.TODO()), // Default sessions are fine with the default controller mu: &sync.Mutex{}, @@ -121,7 +121,7 @@ func NewDoltSession( dbStates: make(map[string]*DatabaseSessionState), dbCache: newDatabaseCache(), provider: pro, - tempTables: make(map[string][]sql.Table), + tempTables: make(map[string]map[string]sql.Table), globalsConf: globals, branchController: branchController, statsProv: statsProvider, @@ -1503,32 +1503,29 @@ func (d *DoltSession) DatabaseCache(ctx *sql.Context) *DatabaseCache { } func (d *DoltSession) AddTemporaryTable(ctx *sql.Context, db string, tbl sql.Table) { - d.tempTables[strings.ToLower(db)] = append(d.tempTables[strings.ToLower(db)], tbl) + db = strings.ToLower(db) + if _, exists := d.tempTables[db]; !exists { + d.tempTables[db] = make(map[string]sql.Table) + } + d.tempTables[db][strings.ToLower(tbl.Name())] = tbl } func (d *DoltSession) DropTemporaryTable(ctx *sql.Context, db, name string) { - tables := d.tempTables[strings.ToLower(db)] - for i, tbl := range d.tempTables[strings.ToLower(db)] { - if strings.EqualFold(tbl.Name(), name) { - tables = append(tables[:i], tables[i+1:]...) - break - } - } - d.tempTables[strings.ToLower(db)] = tables + delete(d.tempTables[strings.ToLower(db)], strings.ToLower(name)) } func (d *DoltSession) GetTemporaryTable(ctx *sql.Context, db, name string) (sql.Table, bool) { - for _, tbl := range d.tempTables[strings.ToLower(db)] { - if strings.EqualFold(tbl.Name(), name) { - return tbl, true - } - } - return nil, false + tmpTbl, exists := d.tempTables[strings.ToLower(db)][strings.ToLower(name)] + return tmpTbl, exists } // GetAllTemporaryTables returns all temp tables for this session. func (d *DoltSession) GetAllTemporaryTables(ctx *sql.Context, db string) ([]sql.Table, error) { - return d.tempTables[strings.ToLower(db)], nil + tmpTables := make([]sql.Table, 0) + for _, tbl := range d.tempTables[strings.ToLower(db)] { + tmpTables = append(tmpTables, tbl) + } + return tmpTables, nil } // CWBHeadRef returns the branch ref for this session HEAD for the database named diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index f0436f4a796..213d91afd6b 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -8833,6 +8833,45 @@ var DoltSystemVariables = []queries.ScriptTest{ // DoltTempTableScripts tests temporary tables. // Temporary tables are not supported in GMS, eventually should move those tests there. var DoltTempTableScripts = []queries.ScriptTest{ + { + Name: "temporary and non-temporary table name collisions", + SetUpScript: []string{ + "create table t1 (id varchar(100))", + "insert into t1 values ('this is a non-temporary table')", + "create temporary table t1 (id varchar(100))", // okay to create temporary table with the same name as a non-temporary table + "insert into t1 values ('this is a temporary table')", + "create temporary table t2 (id int)", + "create table t2 (id int)", // okay to create a non-temporary table with the same name as a temporary table + "create temporary table t3 (id int)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // should show temporary table t1 + Query: "select * from t1", + Expected: []sql.Row{{"this is a temporary table"}}, + }, + { + // cannot create a temporary table with the same name as another temporary table + Query: "create temporary table t3 (id int)", + ExpectedErr: sql.ErrTableAlreadyExists, + }, + { + // should show temporary table + Query: "show create table t1", + Expected: []sql.Row{{"t1", "CREATE TEMPORARY TABLE `t1` (\n `id` varchar(100)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + { + // should drop temporary table t1 + Query: "drop table t1", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + // should show non-temporary table that was created before + Query: "show create table t1", + Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `id` varchar(100)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}}, + }, + }, + }, { Name: "temporary table supports auto increment", SetUpScript: []string{