@@ -596,19 +596,30 @@ async def handle( # type: ignore
596596 conn_factory , repository , snapshot , logger
597597 )
598598
599+ internal_tables = RestoreInternalTables (
600+ conn_factory , repository , snapshot , logger
601+ )
599602 if restore_type == SnapshotRestoreType .ALL .value :
600- await self ._remove_gc_tables (conn_factory , repository , snapshot , logger )
603+ await internal_tables .remove_duplicated_tables ()
604+ elif restore_type == SnapshotRestoreType .TABLES .value :
605+ await internal_tables .remove_duplicated_tables (tables )
601606
602- await self ._start_restore_snapshot (
603- conn_factory ,
604- repository ,
605- snapshot ,
606- restore_type ,
607- logger ,
608- tables ,
609- partitions ,
610- sections ,
611- )
607+ try :
608+ await self ._start_restore_snapshot (
609+ conn_factory ,
610+ repository ,
611+ snapshot ,
612+ restore_type ,
613+ logger ,
614+ tables ,
615+ partitions ,
616+ sections ,
617+ )
618+ except kopf .PermanentError as e :
619+ await internal_tables .restore_tables ()
620+ raise e
621+ else :
622+ await internal_tables .cleanup_tables ()
612623
613624 @staticmethod
614625 async def _create_backup_repository (
@@ -707,51 +718,6 @@ async def _ensure_snapshot_exists(
707718 logger .warning ("DatabaseError in _ensure_snapshot_exists" , exc_info = e )
708719 raise kopf .PermanentError ("Snapshots could not be fetched." )
709720
@staticmethod
async def _remove_gc_tables(
    conn_factory,
    repository: str,
    snapshot: str,
    logger: logging.Logger,
):
    """
    Drop the grand-central tables of the cluster that the snapshot will
    recreate.

    The tables listed for the snapshot in ``sys.snapshots`` are filtered to
    those whose name starts with ``gc.``. Each of them is probed with a
    one-row ``SELECT`` and dropped only when that probe returns a row.

    :param conn_factory: A function that establishes a database connection to
        the CrateDB cluster used for SQL queries.
    :param repository: The name of the repository.
    :param snapshot: The name of the snapshot to restore.
    :param logger: the logger on which we're logging
    :raises kopf.PermanentError: if any of the statements fails with a
        ``DatabaseError``.
    """
    logger.info("Start _remove_gc_tables")
    try:
        async with conn_factory() as conn:
            async with conn.cursor(timeout=120) as cursor:
                # ``%%`` is a literal ``%``: the statement goes through the
                # driver's parameter formatting.
                await cursor.execute(
                    "WITH tables AS ("
                    "  SELECT unnest(tables) AS t "
                    "  FROM sys.snapshots "
                    "  WHERE repository=%s AND name=%s"
                    ") "
                    "SELECT * FROM tables WHERE t LIKE 'gc.%%';",
                    (repository, snapshot),
                )
                tables = await cursor.fetchall()
                logger.info(f"tables: {tables}")
                for (table,) in tables:
                    logger.info(f"table: {table}")
                    # NOTE(review): the probe only finds non-empty tables; an
                    # empty table is left in place and a missing one makes
                    # the SELECT itself fail - confirm this is intended.
                    await cursor.execute(f"SELECT * FROM {table} LIMIT 1;")
                    row = await cursor.fetchone()
                    logger.info(f"row: {row}")
                    if row:
                        logger.info(f"Dropping table: {table}")
                        await cursor.execute(f"DROP TABLE {table};")

    except DatabaseError as e:
        # Name this function in the log so the failure is not attributed to
        # _ensure_snapshot_exists.
        logger.warning("DatabaseError in _remove_gc_tables", exc_info=e)
        raise kopf.PermanentError("Snapshots could not be fetched.") from e
755721 @staticmethod
756722 async def _start_restore_snapshot (
757723 conn_factory ,
@@ -1162,3 +1128,110 @@ async def handle( # type: ignore
11621128 name = name ,
11631129 body = body ,
11641130 )
1131+
1132+
class RestoreInternalTables:
    """
    Move grand-central (``gc.*``) tables out of the way of a snapshot restore.

    Before the restore, the gc tables listed in the snapshot are renamed to
    ``<table>_temp``. Afterwards the caller either renames them back
    (:meth:`restore_tables`) or drops them (:meth:`cleanup_tables`).
    """

    def __init__(
        self,
        conn_factory,
        repository: str,
        snapshot: str,
        logger: logging.Logger,
    ):
        """
        :param conn_factory: A function that establishes a database connection
            to the CrateDB cluster used for SQL queries.
        :param repository: The name of the repository.
        :param snapshot: The name of the snapshot to restore.
        :param logger: the logger on which we're logging
        """
        self.conn_factory = conn_factory
        self.repository: str = repository
        self.snapshot: str = snapshot
        self.logger: logging.Logger = logger

        # True once at least one table has been renamed to ``<table>_temp``.
        self.gc_tables_renamed: bool = False
        # Original names of the tables that were successfully renamed.
        self.gc_tables: list[str] = []

    async def remove_duplicated_tables(
        self, tables: Optional[List[str]] = None
    ) -> None:
        """
        Rename the grand-central tables contained in the snapshot to
        ``<table>_temp`` so the restore can recreate them.

        :param tables: When given, only the ``gc.`` tables of this list are
            looked up in the snapshot; when ``None``, every snapshot table
            matching ``gc.%`` is used.
        :raises kopf.PermanentError: if a statement fails with a
            ``DatabaseError``.
        """
        # Bookkeeping reflects only what this call actually renamed.
        self.gc_tables = []
        self.gc_tables_renamed = False
        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    params: tuple = (self.repository, self.snapshot)
                    if tables is not None:
                        requested = self.get_gc_tables(cursor, tables)
                        if not requested:
                            # Nothing to rename; ``t IN ()`` would be invalid
                            # SQL and fail the whole restore.
                            return
                        # Bind the names as parameters instead of splicing
                        # them into the statement text.
                        # NOTE(review): get_gc_tables() returns identifier-
                        # quoted names, which may not equal the plain names
                        # stored in sys.snapshots - confirm.
                        placeholders = ", ".join(["%s"] * len(requested))
                        where_stmt = f"t IN ({placeholders})"
                        params += tuple(requested)
                    else:
                        # ``%%`` is a literal ``%`` after parameter formatting.
                        where_stmt = "t LIKE 'gc.%%'"

                    await cursor.execute(
                        "WITH tables AS ("
                        "  SELECT unnest(tables) AS t "
                        "  FROM sys.snapshots "
                        "  WHERE repository=%s AND name=%s"
                        ") "
                        f"SELECT * FROM tables WHERE {where_stmt};",
                        params,
                    )
                    rows = await cursor.fetchall()
                    for table in [row[0] for row in rows or []]:
                        self.logger.info(
                            f"Renaming GC table: {table} to {table}_temp"
                        )
                        await cursor.execute(
                            f"ALTER TABLE {table} RENAME TO {table}_temp;"
                        )
                        # Record only after the rename succeeded.
                        self.gc_tables.append(table)
                        self.gc_tables_renamed = True
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.remove_duplicated_tables",
                exc_info=e,
            )
            raise kopf.PermanentError(
                "internal tables couldn't be renamed."
            ) from e

    async def restore_tables(self) -> None:
        """
        If the restore operation failed, rename back the gc tables
        to their original names.

        :raises kopf.PermanentError: if a rename fails with a
            ``DatabaseError``.
        """
        if not self.gc_tables_renamed:
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        self.logger.info(
                            f"Renaming GC table: {table}_temp to {table}"
                        )
                        await cursor.execute(
                            f"ALTER TABLE {table}_temp RENAME TO {table};"
                        )
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.restore_tables",
                exc_info=e,
            )
            raise kopf.PermanentError(
                "internal table couldn't be renamed."
            ) from e

    async def cleanup_tables(self) -> None:
        """
        After a successful restore, the temporary renamed gc tables can be
        dropped.

        :raises kopf.PermanentError: if a drop fails with a ``DatabaseError``.
        """
        if not self.gc_tables_renamed:
            return

        try:
            async with self.conn_factory() as conn:
                async with conn.cursor(timeout=120) as cursor:
                    for table in self.gc_tables:
                        self.logger.info(f"Dropping old GC table: {table}_temp")
                        await cursor.execute(f"DROP TABLE {table}_temp;")
        except DatabaseError as e:
            self.logger.warning(
                "DatabaseError in RestoreInternalTables.cleanup_tables",
                exc_info=e,
            )
            raise kopf.PermanentError(
                "internal tables couldn't be dropped."
            ) from e

    @staticmethod
    def get_gc_tables(cursor, tables: list[str]) -> list[str]:
        """
        Return the entries of ``tables`` that start with ``gc.``, each passed
        through ``quote_ident`` using the cursor's underlying implementation.
        """
        return [
            quote_ident(table, cursor._impl)
            for table in tables
            if table.startswith("gc.")
        ]
0 commit comments