@@ -194,13 +194,19 @@ def compile_alter_sql(self, table):
194
194
else :
195
195
default = ""
196
196
197
+ column_constraint = ""
198
+ if column .column_type == "enum" :
199
+ values = ", " .join (f"'{ x } '" for x in column .values )
200
+ column_constraint = f" CHECK({ column .name } IN ({ values } ))"
201
+
197
202
add_columns .append (
198
203
self .add_column_string ()
199
204
.format (
200
205
name = self .wrap_column (column .name ),
201
206
data_type = self .type_map .get (column .column_type , "" ),
202
207
length = length ,
203
208
constraint = "PRIMARY KEY" if column .primary else "" ,
209
+ column_constraint = column_constraint ,
204
210
nullable = "NULL" if column .is_null else "NOT NULL" ,
205
211
default = default ,
206
212
after = (" AFTER " + self .wrap_column (column ._after ))
@@ -263,12 +269,18 @@ def compile_alter_sql(self, table):
263
269
changed_sql = []
264
270
265
271
for name , column in table .changed_columns .items ():
272
+
273
+ column_constraint = ""
274
+ if column .column_type == "enum" :
275
+ values = ", " .join (f"'{ x } '" for x in column .values )
276
+ column_constraint = f" CHECK({ column .name } IN ({ values } ))"
266
277
changed_sql .append (
267
278
self .modify_column_string ()
268
279
.format (
269
280
name = self .wrap_column (name ),
270
281
data_type = self .type_map .get (column .column_type ),
271
- nullable = "NULL" if column .is_null else "NOT NULL" ,
282
+ column_constraint = column_constraint ,
283
+ constraint = "PRIMARY KEY" if column .primary else "" ,
272
284
length = "(" + str (column .length ) + ")"
273
285
if column .column_type not in self .types_without_lengths
274
286
else "" ,
@@ -380,13 +392,13 @@ def alter_format_add_foreign_key(self):
380
392
return "ALTER TABLE {table} {columns}"
381
393
382
394
def add_column_string (self ):
383
- return "ADD COLUMN {name} {data_type}{length} {nullable}{default} {constraint}"
395
+ return "ADD COLUMN {name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}"
384
396
385
397
def drop_column_string (self ):
386
398
return "DROP COLUMN {name}"
387
399
388
400
def modify_column_string (self ):
389
- return "ALTER COLUMN {name} TYPE {data_type}{length}"
401
+ return "ALTER COLUMN {name} TYPE {data_type}{length}{column_constraint} {constraint} "
390
402
391
403
def rename_column_string (self ):
392
404
return "RENAME COLUMN {old} TO {to}"
0 commit comments