Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions tools/base/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib
import sys
from contextlib import contextmanager
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -119,39 +120,69 @@ def test_util_coverage_with_data_file(patches):
== [(m_open.return_value.__enter__.return_value,), {}])


def test_util_extract(patches):

@pytest.mark.parametrize(
"tarballs",
[(), tuple("TARB{i}" for i in range(0, 3))])
def test_util_extract(patches, tarballs):
patched = patches(
"tempfile.TemporaryDirectory",
"nested",
"tarfile.open",
prefix="tools.base.utils")

with patched as (m_tmp, m_open):
assert utils.extract("TARBALL", "PATH") == "PATH"
with patched as (m_nested, m_open):
_extractions = [MagicMock(), MagicMock()]
m_nested.return_value.__enter__.return_value = _extractions

if tarballs:
assert utils.extract("PATH", *tarballs) == "PATH"
else:

with pytest.raises(utils.ExtractError) as e:
utils.extract("PATH", *tarballs) == "PATH"

if not tarballs:
assert (
e.value.args[0]
== 'No tarballs specified for extraction to PATH')
assert not m_nested.called
assert not m_open.called
for _extract in _extractions:
assert not _extract.extractall.called
return

for _extract in _extractions:
assert (
list(_extract.extractall.call_args)
== [(), dict(path="PATH")])

assert (
list(m_open.call_args)
== [('TARBALL',), {}])
list(m_open.call_args_list)
== [[(tarb, ), {}] for tarb in tarballs])
assert (
list(m_open.return_value.__enter__.return_value.extractall.call_args)
== [(), {'path': "PATH"}])
list(m_nested.call_args)
== [tuple(m_open.return_value for x in tarballs), {}])


def test_util_untar(patches):
@pytest.mark.parametrize(
"tarballs",
[(), tuple("TARB{i}" for i in range(0, 3))])
def test_util_untar(patches, tarballs):
patched = patches(
"tempfile.TemporaryDirectory",
"extract",
prefix="tools.base.utils")

with patched as (m_tmp, m_extract):
with utils.untar("PATH") as tmpdir:
with utils.untar(*tarballs) as tmpdir:
assert tmpdir == m_extract.return_value

assert (
list(m_tmp.call_args)
== [(), {}])
assert (
list(m_extract.call_args)
== [('PATH', m_tmp.return_value.__enter__.return_value), {}])
== [(m_tmp.return_value.__enter__.return_value, ) + tarballs, {}])


def test_util_from_yaml(patches):
Expand Down
23 changes: 17 additions & 6 deletions tools/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import io
import os
import pathlib
import tarfile
import tempfile
from configparser import ConfigParser
Expand All @@ -13,6 +14,10 @@
import yaml


class ExtractError(Exception):
pass


# this is testing specific - consider moving to tools.testing.utils
@contextmanager
def coverage_with_data_file(data_file: str) -> Iterator[str]:
Expand Down Expand Up @@ -74,14 +79,20 @@ def buffered(
stderr.extend(mangle(_stderr.read().strip().split("\n")))


def extract(tarball: str, path: str) -> str:
with tarfile.open(tarball) as tarfiles:
tarfiles.extractall(path=path)
return path
def extract(path: Union[pathlib.Path, str], *tarballs: Union[pathlib.Path,
str]) -> Union[pathlib.Path, str]:
if not tarballs:
raise ExtractError(f"No tarballs specified for extraction to {path}")
openers = nested(*tuple(tarfile.open(tarball) for tarball in tarballs))

with openers as tarfiles:
for tar in tarfiles:
tar.extractall(path=path)
return path


@contextmanager
def untar(tarball: str) -> Iterator[str]:
def untar(*tarballs: str) -> Iterator[str]:
"""Untar a tarball into a temporary directory

for example to list the contents of a tarball:
Expand All @@ -102,7 +113,7 @@ def untar(tarball: str) -> Iterator[str]:

"""
with tempfile.TemporaryDirectory() as tmpdir:
yield extract(tarball, tmpdir)
yield extract(tmpdir, *tarballs)


def from_yaml(path: str) -> Union[dict, list, str, int]:
Expand Down
3 changes: 1 addition & 2 deletions tools/docs/sphinx_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ def rst_dir(self) -> str:
"""
rst_dir = os.path.join(self.build_dir, "generated/rst")
if self.rst_tar:
with tarfile.open(self.rst_tar) as tarfiles:
tarfiles.extractall(path=rst_dir)
utils.extract(rst_dir, self.rst_tar)
return rst_dir

@property
Expand Down
13 changes: 5 additions & 8 deletions tools/docs/tests/test_sphinx_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,12 @@ def test_sphinx_runner_rst_dir(patches, rst_tar):
runner = sphinx_runner.SphinxRunner()
patched = patches(
"os.path",
"tarfile",
"utils",
("SphinxRunner.build_dir", dict(new_callable=PropertyMock)),
("SphinxRunner.rst_tar", dict(new_callable=PropertyMock)),
prefix="tools.docs.sphinx_runner")

with patched as (m_path, m_tar, m_dir, m_rst):
with patched as (m_path, m_utils, m_dir, m_rst):
m_rst.return_value = rst_tar
assert runner.rst_dir == m_path.join.return_value

Expand All @@ -262,13 +262,10 @@ def test_sphinx_runner_rst_dir(patches, rst_tar):

if rst_tar:
assert (
list(m_tar.open.call_args)
== [(rst_tar,), {}])
assert (
list(m_tar.open.return_value.__enter__.return_value.extractall.call_args)
== [(), {'path': m_path.join.return_value}])
list(m_utils.extract.call_args)
== [(m_path.join.return_value, rst_tar), {}])
else:
assert not m_tar.open.called
assert not m_utils.extract.called
assert "rst_dir" in runner.__dict__


Expand Down