diff --git a/tests/decryption/test_decrypt.py b/tests/decryption/test_decrypt.py index c579eb6..94ed41e 100644 --- a/tests/decryption/test_decrypt.py +++ b/tests/decryption/test_decrypt.py @@ -8,10 +8,11 @@ import pytest from crypt4gh_middleware.decrypt import ( - get_private_keys, decrypt_files, + get_args, + get_private_keys, move_files, - get_args + remove_files, ) from tests.utils import patch_cli @@ -19,6 +20,16 @@ INPUT_TEXT = "hello world from the input!" +@pytest.fixture(name="files") +def fixture_files(tmp_path): + """Returns list of input file paths.""" + files = [INPUT_DIR / "hello.txt", INPUT_DIR / "hello.c4gh", INPUT_DIR / "alice.sec"] + temp_files = [tmp_path / file.name for file in files] + for src, dest in zip(files, temp_files): + shutil.copy(src, dest) + return temp_files + + class TestGetPrivateKeys: """Test get_private_keys.""" @@ -110,15 +121,6 @@ def file_contents_are_valid(): class TestMoveFiles: """Test move_files.""" - @pytest.fixture(name="files") - def fixture_files(self, tmp_path): - """Returns list of input file paths.""" - files = [INPUT_DIR/"hello.txt", INPUT_DIR/"hello.c4gh", INPUT_DIR/"alice.sec"] - temp_files = [tmp_path/file.name for file in files] - for src, dest in zip(files, temp_files): - shutil.copy(src, dest) - return temp_files - def test_empty_list(self, tmp_path): """Test that no error is thrown with an empty list.""" move_files(file_paths=[], output_dir=tmp_path) @@ -151,6 +153,35 @@ def test_permission_error(self, tmp_path): move_files(file_paths=[INPUT_DIR/"hello.txt"], output_dir=output_dir) +class TestRemoveFiles: + """Test remove_files.""" + + def test_remove_files(self, files, tmp_path): + """Test that the files in a directory are removed successfully.""" + assert all(file.exists() for file in files) + remove_files(tmp_path) + assert not any(file.exists() for file in files) + + def test_empty_dir(self, tmp_path): + """Test that no error is raised when an empty directory is passed.""" + empty_dir = tmp_path/"empty" + empty_dir.mkdir() + remove_files(empty_dir) + assert len(list(empty_dir.iterdir())) == 0 + + def test_dir_does_not_exist(self, tmp_path): + """Test that a value error is raised when a non-existent directory is passed.""" + with pytest.raises(ValueError): + remove_files(tmp_path/"bad_dir") + + def test_non_directory_path(self, tmp_path): + """Test that a value error is raised when a non-directory path is passed.""" + non_dir_path = tmp_path/"file.txt" + non_dir_path.touch() + with pytest.raises(ValueError): + remove_files(non_dir_path) + + class TestGetArgs: """Test get_args."""