Skip to content
19 changes: 15 additions & 4 deletions go/libraries/doltcore/sqle/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -2049,7 +2049,7 @@ OuterLoop:
}, nil
}

// createSqlTable is the private version of CreateTable. It doesn't enforce any table name checks.
// createSqlTable is the private version of CreateTable.
func (db Database) createSqlTable(ctx *sql.Context, table string, schemaName string, sch sql.PrimaryKeySchema, collation sql.CollationID, comment string) error {
ws, err := db.GetWorkingSet(ctx)
if err != nil {
Expand All @@ -2066,6 +2066,8 @@ func (db Database) createSqlTable(ctx *sql.Context, table string, schemaName str
}

tableName := doltdb.TableName{Name: table, Schema: schemaName}
// TODO: This check is also done in createDoltTable, which is called at the end of this function, meaning it's done
// multiple times. Consider refactoring out.
if exists, err := root.HasTable(ctx, tableName); err != nil {
return err
} else if exists {
Expand Down Expand Up @@ -2118,6 +2120,8 @@ func (db Database) createIndexedSqlTable(ctx *sql.Context, table string, schemaN
}

tableName := doltdb.TableName{Name: table, Schema: schemaName}
// TODO: This check is also done in createDoltTable, which is called at the end of this function, meaning it's done
// multiple times. Consider refactoring out.
if exists, err := root.HasTable(ctx, tableName); err != nil {
return err
} else if exists {
Expand Down Expand Up @@ -2160,6 +2164,8 @@ func (db Database) createIndexedSqlTable(ctx *sql.Context, table string, schemaN

// createDoltTable creates a table on the database using the given dolt schema while not enforcing table baseName checks.
func (db Database) createDoltTable(ctx *sql.Context, tableName string, schemaName string, root doltdb.RootValue, doltSch schema.Schema) error {
// TODO: This check is also done in createSqlTable and createIndexedSqlTable, which both call createDoltTable,
// meaning it's done multiple times. Consider refactoring.
if exists, err := root.HasTable(ctx, doltdb.TableName{Name: tableName, Schema: schemaName}); err != nil {
return err
} else if exists {
Expand Down Expand Up @@ -2202,13 +2208,18 @@ func (db Database) CreateTemporaryTable(ctx *sql.Context, tableName string, pkSc
return ErrInvalidTableName.New(tableName)
}

tmp, err := NewTempTable(ctx, db.ddb, pkSch, tableName, db.Name(), db.editOpts, collation)
ds := dsess.DSessFromSess(ctx.Session)
databaseName := db.Name()
if _, exists := ds.GetTemporaryTable(ctx, databaseName, tableName); exists {
return sql.ErrTableAlreadyExists.New(tableName)
}

tmp, err := NewTempTable(ctx, db.ddb, pkSch, tableName, databaseName, db.editOpts, collation)
if err != nil {
return err
}

ds := dsess.DSessFromSess(ctx.Session)
ds.AddTemporaryTable(ctx, db.Name(), tmp)
ds.AddTemporaryTable(ctx, databaseName, tmp)
return nil
}

Expand Down
35 changes: 16 additions & 19 deletions go/libraries/doltcore/sqle/dsess/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ type DoltSession struct {
branchController *branch_control.Controller
dbCache *DatabaseCache
dbStates map[string]*DatabaseSessionState
tempTables map[string][]sql.Table
tempTables map[string]map[string]sql.Table
gcSafepointController *gcctx.GCSafepointController

writeSessProv WriteSessFunc
Expand All @@ -90,7 +90,7 @@ func DefaultSession(pro DoltDatabaseProvider, sessFunc WriteSessFunc) *DoltSessi
dbStates: make(map[string]*DatabaseSessionState),
dbCache: newDatabaseCache(),
provider: pro,
tempTables: make(map[string][]sql.Table),
tempTables: make(map[string]map[string]sql.Table),
globalsConf: config.NewMapConfig(make(map[string]string)),
branchController: branch_control.CreateDefaultController(context.TODO()), // Default sessions are fine with the default controller
mu: &sync.Mutex{},
Expand Down Expand Up @@ -121,7 +121,7 @@ func NewDoltSession(
dbStates: make(map[string]*DatabaseSessionState),
dbCache: newDatabaseCache(),
provider: pro,
tempTables: make(map[string][]sql.Table),
tempTables: make(map[string]map[string]sql.Table),
globalsConf: globals,
branchController: branchController,
statsProv: statsProvider,
Expand Down Expand Up @@ -1503,32 +1503,29 @@ func (d *DoltSession) DatabaseCache(ctx *sql.Context) *DatabaseCache {
}

func (d *DoltSession) AddTemporaryTable(ctx *sql.Context, db string, tbl sql.Table) {
d.tempTables[strings.ToLower(db)] = append(d.tempTables[strings.ToLower(db)], tbl)
db = strings.ToLower(db)
if _, exists := d.tempTables[db]; !exists {
d.tempTables[db] = make(map[string]sql.Table)
}
d.tempTables[db][strings.ToLower(tbl.Name())] = tbl
}

func (d *DoltSession) DropTemporaryTable(ctx *sql.Context, db, name string) {
tables := d.tempTables[strings.ToLower(db)]
for i, tbl := range d.tempTables[strings.ToLower(db)] {
if strings.EqualFold(tbl.Name(), name) {
tables = append(tables[:i], tables[i+1:]...)
break
}
}
d.tempTables[strings.ToLower(db)] = tables
delete(d.tempTables[strings.ToLower(db)], strings.ToLower(name))
}

func (d *DoltSession) GetTemporaryTable(ctx *sql.Context, db, name string) (sql.Table, bool) {
for _, tbl := range d.tempTables[strings.ToLower(db)] {
if strings.EqualFold(tbl.Name(), name) {
return tbl, true
}
}
return nil, false
tmpTbl, exists := d.tempTables[strings.ToLower(db)][strings.ToLower(name)]
return tmpTbl, exists
}

// GetAllTemporaryTables returns all temp tables for this session.
func (d *DoltSession) GetAllTemporaryTables(ctx *sql.Context, db string) ([]sql.Table, error) {
return d.tempTables[strings.ToLower(db)], nil
tmpTables := make([]sql.Table, 0)
for _, tbl := range d.tempTables[strings.ToLower(db)] {
tmpTables = append(tmpTables, tbl)
}
return tmpTables, nil
}

// CWBHeadRef returns the branch ref for this session HEAD for the database named
Expand Down
39 changes: 39 additions & 0 deletions go/libraries/doltcore/sqle/enginetest/dolt_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -8833,6 +8833,45 @@ var DoltSystemVariables = []queries.ScriptTest{
// DoltTempTableScripts tests temporary tables.
// Temporary tables are not supported in GMS, eventually should move those tests there.
var DoltTempTableScripts = []queries.ScriptTest{
{
Name: "temporary and non-temporary table name collisions",
SetUpScript: []string{
"create table t1 (id varchar(100))",
"insert into t1 values ('this is a non-temporary table')",
"create temporary table t1 (id varchar(100))", // okay to create temporary table with the same name as a non-temporary table
"insert into t1 values ('this is a temporary table')",
"create temporary table t2 (id int)",
"create table t2 (id int)", // okay to create a non-temporary table with the same name as a temporary table
"create temporary table t3 (id int)",
},
Assertions: []queries.ScriptTestAssertion{
{
// should show temporary table t1
Query: "select * from t1",
Expected: []sql.Row{{"this is a temporary table"}},
},
{
// cannot create a temporary table with the same name as another temporary table
Query: "create temporary table t3 (id int)",
ExpectedErr: sql.ErrTableAlreadyExists,
},
{
// should show temporary table
Query: "show create table t1",
Expected: []sql.Row{{"t1", "CREATE TEMPORARY TABLE `t1` (\n `id` varchar(100)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
// should drop temporary table t1
Query: "drop table t1",
Expected: []sql.Row{{types.NewOkResult(0)}},
},
{
// should show non-temporary table that was created before
Query: "show create table t1",
Expected: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `id` varchar(100)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "temporary table supports auto increment",
SetUpScript: []string{
Expand Down
Loading