diff --git a/tools/base/tests/test_utils.py b/tools/base/tests/test_utils.py index dc403ec999eb4..bfbde71a2fd47 100644 --- a/tools/base/tests/test_utils.py +++ b/tools/base/tests/test_utils.py @@ -1,6 +1,7 @@ import importlib import sys from contextlib import contextmanager +from unittest.mock import MagicMock import pytest @@ -119,31 +120,61 @@ def test_util_coverage_with_data_file(patches): == [(m_open.return_value.__enter__.return_value,), {}]) -def test_util_extract(patches): + +@pytest.mark.parametrize( + "tarballs", + [(), tuple("TARB{i}" for i in range(0, 3))]) +def test_util_extract(patches, tarballs): patched = patches( - "tempfile.TemporaryDirectory", + "nested", "tarfile.open", prefix="tools.base.utils") - with patched as (m_tmp, m_open): - assert utils.extract("TARBALL", "PATH") == "PATH" + with patched as (m_nested, m_open): + _extractions = [MagicMock(), MagicMock()] + m_nested.return_value.__enter__.return_value = _extractions + + if tarballs: + assert utils.extract("PATH", *tarballs) == "PATH" + else: + + with pytest.raises(utils.ExtractError) as e: + utils.extract("PATH", *tarballs) == "PATH" + + if not tarballs: + assert ( + e.value.args[0] + == 'No tarballs specified for extraction to PATH') + assert not m_nested.called + assert not m_open.called + for _extract in _extractions: + assert not _extract.extractall.called + return + + for _extract in _extractions: + assert ( + list(_extract.extractall.call_args) + == [(), dict(path="PATH")]) assert ( - list(m_open.call_args) - == [('TARBALL',), {}]) + list(m_open.call_args_list) + == [[(tarb, ), {}] for tarb in tarballs]) assert ( - list(m_open.return_value.__enter__.return_value.extractall.call_args) - == [(), {'path': "PATH"}]) + list(m_nested.call_args) + == [tuple(m_open.return_value for x in tarballs), {}]) -def test_util_untar(patches): +@pytest.mark.parametrize( + "tarballs", + [(), tuple("TARB{i}" for i in range(0, 3))]) +def test_util_untar(patches, tarballs): patched = patches( "tempfile.TemporaryDirectory", "extract", prefix="tools.base.utils") with patched as (m_tmp, m_extract): - with utils.untar("PATH") as tmpdir: + with utils.untar(*tarballs) as tmpdir: assert tmpdir == m_extract.return_value assert ( @@ -151,7 +182,7 @@ def test_util_untar(patches): == [(), {}]) assert ( list(m_extract.call_args) - == [('PATH', m_tmp.return_value.__enter__.return_value), {}]) + == [(m_tmp.return_value.__enter__.return_value, ) + tarballs, {}]) def test_util_from_yaml(patches): diff --git a/tools/base/utils.py b/tools/base/utils.py index 379e8f4326333..01813dbb0eea9 100644 --- a/tools/base/utils.py +++ b/tools/base/utils.py @@ -4,6 +4,7 @@ import io import os +import pathlib import tarfile import tempfile from configparser import ConfigParser @@ -13,6 +14,10 @@ import yaml +class ExtractError(Exception): + pass + + # this is testing specific - consider moving to tools.testing.utils @contextmanager def coverage_with_data_file(data_file: str) -> Iterator[str]: @@ -74,14 +79,20 @@ def buffered( stderr.extend(mangle(_stderr.read().strip().split("\n"))) -def extract(tarball: str, path: str) -> str: - with tarfile.open(tarball) as tarfiles: - tarfiles.extractall(path=path) - return path +def extract(path: Union[pathlib.Path, str], *tarballs: Union[pathlib.Path, + str]) -> Union[pathlib.Path, str]: + if not tarballs: + raise ExtractError(f"No tarballs specified for extraction to {path}") + openers = nested(*tuple(tarfile.open(tarball) for tarball in tarballs)) + + with openers as tarfiles: + for tar in tarfiles: + tar.extractall(path=path) + return path @contextmanager -def untar(tarball: str) -> Iterator[str]: +def untar(*tarballs: str) -> Iterator[str]: """Untar a tarball into a temporary directory for example to list the contents of a tarball: @@ -102,7 +113,7 @@ def untar(tarball: str) -> Iterator[str]: """ with tempfile.TemporaryDirectory() as tmpdir: - yield extract(tarball, tmpdir) + yield extract(tmpdir, *tarballs) def from_yaml(path: str) -> Union[dict, list, str, int]: diff --git a/tools/docs/sphinx_runner.py b/tools/docs/sphinx_runner.py index 2550bd7c57299..cd211bdf1a996 100644 --- a/tools/docs/sphinx_runner.py +++ b/tools/docs/sphinx_runner.py @@ -113,8 +113,7 @@ def rst_dir(self) -> str: """ rst_dir = os.path.join(self.build_dir, "generated/rst") if self.rst_tar: - with tarfile.open(self.rst_tar) as tarfiles: - tarfiles.extractall(path=rst_dir) + utils.extract(rst_dir, self.rst_tar) return rst_dir @property diff --git a/tools/docs/tests/test_sphinx_runner.py b/tools/docs/tests/test_sphinx_runner.py index f85ddd4004342..9e2f1df491aab 100644 --- a/tools/docs/tests/test_sphinx_runner.py +++ b/tools/docs/tests/test_sphinx_runner.py @@ -247,12 +247,12 @@ def test_sphinx_runner_rst_dir(patches, rst_tar): runner = sphinx_runner.SphinxRunner() patched = patches( "os.path", - "tarfile", + "utils", ("SphinxRunner.build_dir", dict(new_callable=PropertyMock)), ("SphinxRunner.rst_tar", dict(new_callable=PropertyMock)), prefix="tools.docs.sphinx_runner") - with patched as (m_path, m_tar, m_dir, m_rst): + with patched as (m_path, m_utils, m_dir, m_rst): m_rst.return_value = rst_tar assert runner.rst_dir == m_path.join.return_value @@ -262,13 +262,10 @@ def test_sphinx_runner_rst_dir(patches, rst_tar): if rst_tar: assert ( - list(m_tar.open.call_args) - == [(rst_tar,), {}]) - assert ( - list(m_tar.open.return_value.__enter__.return_value.extractall.call_args) - == [(), {'path': m_path.join.return_value}]) + list(m_utils.extract.call_args) + == [(m_path.join.return_value, rst_tar), {}]) else: - assert not m_tar.open.called + assert not m_utils.extract.called assert "rst_dir" in runner.__dict__