diff --git a/db-auto-backup.py b/db-auto-backup.py index d289e3b..b881339 100755 --- a/db-auto-backup.py +++ b/db-auto-backup.py @@ -20,6 +20,7 @@ class BackupProvider(NamedTuple): + name: str patterns: list[str] backup_method: Callable[[Container], str] file_extension: str @@ -128,17 +129,22 @@ def backup_redis(container: Container) -> str: BACKUP_PROVIDERS: list[BackupProvider] = [ BackupProvider( + name="postgres", patterns=["postgres", "tensorchord/pgvecto-rs", "nextcloud/aio-postgresql"], backup_method=backup_psql, file_extension="sql", ), BackupProvider( + name="mysql", patterns=["mysql", "mariadb", "linuxserver/mariadb"], backup_method=backup_mysql, file_extension="sql", ), BackupProvider( - patterns=["redis"], backup_method=backup_redis, file_extension="rdb" + name="redis", + patterns=["redis"], + backup_method=backup_redis, + file_extension="rdb", ), ] @@ -194,13 +200,15 @@ def backup(now: datetime) -> None: backup_command = backup_provider.backup_method(container) _, output = container.exec_run(backup_command, stream=True, demux=True) + description = f"{container.name} ({backup_provider.name})" + with open_file_compressed( backup_temp_file_path, COMPRESSION ) as backup_temp_file: with tqdm.wrapattr( backup_temp_file, method="write", - desc=container.name, + desc=description, disable=not SHOW_PROGRESS, ) as f: for stdout, _ in output: @@ -211,7 +219,7 @@ def backup(now: datetime) -> None: os.replace(backup_temp_file_path, backup_file) if not SHOW_PROGRESS: - print(container.name) + print(description) backed_up_containers.append(container.name) diff --git a/tests/tests.py b/tests/tests.py index 4d060fa..d993758 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -100,9 +100,29 @@ def test_uptime_kuma_success_hook_url(monkeypatch: Any) -> None: ("docker.io/postgres:14-alpine", "postgres"), ("ghcr.io/realorangeone/db-auto-backup:latest", "realorangeone/db-auto-backup"), ("theorangeone/db-auto-backup:latest:latest", "theorangeone/db-auto-backup"), + ("lscr.io/linuxserver/mariadb:latest", "linuxserver/mariadb"), ], ) def test_get_container_names(tag: str, name: str) -> None: container = MagicMock() container.image.tags = [tag] assert db_auto_backup.get_container_names(container) == {name} + + +@pytest.mark.parametrize( + "container_name,name", + [ + ("postgres", "postgres"), + ("mysql", "mysql"), + ("mariadb", "mysql"), + ("linuxserver/mariadb", "mysql"), + ("tensorchord/pgvecto-rs", "postgres"), + ("nextcloud/aio-postgresql", "postgres"), + ("redis", "redis"), + ], +) +def test_get_backup_provider(container_name: str, name: str) -> None: + provider = db_auto_backup.get_backup_provider([container_name]) + + assert provider is not None + assert provider.name == name