@@ -15,52 +15,85 @@ import (
15
15
_ "github.com/mattn/go-sqlite3"
16
16
)
17
17
18
- func (s * SQLiteState ) migrateSchemaIfNecessary () (defErr error ) {
18
+ func initSQLiteDB (conn * sql.DB ) (defErr error ) {
19
+ // Start with a transaction to avoid "database locked" errors.
20
+ // See https://github.com/mattn/go-sqlite3/issues/274#issuecomment-1429054597
21
+ tx , err := conn .Begin ()
22
+ if err != nil {
23
+ return fmt .Errorf ("beginning transaction: %w" , err )
24
+ }
25
+ defer func () {
26
+ if defErr != nil {
27
+ if err := tx .Rollback (); err != nil {
28
+ logrus .Errorf ("Rolling back transaction to create tables: %v" , err )
29
+ }
30
+ }
31
+ }()
32
+
33
+ sameSchema , err := migrateSchemaIfNecessary (tx )
34
+ if err != nil {
35
+ return err
36
+ }
37
+ if ! sameSchema {
38
+ if err := createSQLiteTables (tx ); err != nil {
39
+ return err
40
+ }
41
+ }
42
+ if err := tx .Commit (); err != nil {
43
+ return fmt .Errorf ("committing transaction: %w" , err )
44
+ }
45
+ return nil
46
+ }
47
+
48
+ func migrateSchemaIfNecessary (tx * sql.Tx ) (bool , error ) {
19
49
// First, check if the DBConfig table exists
20
- checkRow := s . conn .QueryRow ("SELECT 1 FROM sqlite_master WHERE type='table' AND name='DBConfig';" )
50
+ checkRow := tx .QueryRow ("SELECT 1 FROM sqlite_master WHERE type='table' AND name='DBConfig';" )
21
51
var check int
22
52
if err := checkRow .Scan (& check ); err != nil {
23
53
if errors .Is (err , sql .ErrNoRows ) {
24
- return nil
54
+ return false , nil
25
55
}
26
- return fmt .Errorf ("checking if DB config table exists: %w" , err )
56
+ return false , fmt .Errorf ("checking if DB config table exists: %w" , err )
27
57
}
28
58
if check != 1 {
29
59
// Table does not exist, fresh database, no need to migrate.
30
- return nil
60
+ return false , nil
31
61
}
32
62
33
- row := s . conn .QueryRow ("SELECT SchemaVersion FROM DBConfig;" )
63
+ row := tx .QueryRow ("SELECT SchemaVersion FROM DBConfig;" )
34
64
var schemaVer int
35
65
if err := row .Scan (& schemaVer ); err != nil {
36
66
if errors .Is (err , sql .ErrNoRows ) {
37
67
// Brand-new, unpopulated DB.
38
68
// Schema was just created, so it has to be the latest.
39
- return nil
69
+ return false , nil
40
70
}
41
- return fmt .Errorf ("scanning schema version from DB config: %w" , err )
71
+ return false , fmt .Errorf ("scanning schema version from DB config: %w" , err )
42
72
}
43
73
44
74
// If the schema version 0 or less, it's invalid
45
75
if schemaVer <= 0 {
46
- return fmt .Errorf ("database schema version %d is invalid: %w" , schemaVer , define .ErrInternal )
76
+ return false , fmt .Errorf ("database schema version %d is invalid: %w" , schemaVer , define .ErrInternal )
47
77
}
48
78
49
- if schemaVer != schemaVersion {
50
- // If the DB is a later schema than we support, we have to error
51
- if schemaVer > schemaVersion {
52
- return fmt .Errorf ("database has schema version %d while this libpod version only supports version %d: %w" ,
53
- schemaVer , schemaVersion , define .ErrInternal )
54
- }
79
+ // Same schema -> nothing do to.
80
+ if schemaVer == schemaVersion {
81
+ return true , nil
82
+ }
55
83
56
- // Perform schema migration here, one version at a time.
84
+ // If the DB is a later schema than we support, we have to error
85
+ if schemaVer > schemaVersion {
86
+ return false , fmt .Errorf ("database has schema version %d while this libpod version only supports version %d: %w" ,
87
+ schemaVer , schemaVersion , define .ErrInternal )
57
88
}
58
89
59
- return nil
90
+ // Perform schema migration here, one version at a time.
91
+
92
+ return false , nil
60
93
}
61
94
62
95
// Initialize all required tables for the SQLite state
63
- func sqliteInitTables ( conn * sql.DB ) ( defErr error ) {
96
+ func createSQLiteTables ( tx * sql.Tx ) error {
64
97
// Technically we could split the "CREATE TABLE IF NOT EXISTS" and ");"
65
98
// bits off each command and add them in the for loop where we actually
66
99
// run the SQL, but that seems unnecessary.
@@ -186,28 +219,11 @@ func sqliteInitTables(conn *sql.DB) (defErr error) {
186
219
"VolumeState" : volumeState ,
187
220
}
188
221
189
- tx , err := conn .Begin ()
190
- if err != nil {
191
- return fmt .Errorf ("beginning transaction: %w" , err )
192
- }
193
- defer func () {
194
- if defErr != nil {
195
- if err := tx .Rollback (); err != nil {
196
- logrus .Errorf ("Rolling back transaction to create tables: %v" , err )
197
- }
198
- }
199
- }()
200
-
201
222
for tblName , cmd := range tables {
202
223
if _ , err := tx .Exec (cmd ); err != nil {
203
224
return fmt .Errorf ("creating table %s: %w" , tblName , err )
204
225
}
205
226
}
206
-
207
- if err := tx .Commit (); err != nil {
208
- return fmt .Errorf ("committing transaction: %w" , err )
209
- }
210
-
211
227
return nil
212
228
}
213
229
0 commit comments