|
8 | 8 | import json
|
9 | 9 | import os
|
10 | 10 | from concurrent.futures import ProcessPoolExecutor
|
| 11 | +from contextlib import contextmanager |
11 | 12 | from tempfile import TemporaryDirectory
|
12 | 13 | from unittest.mock import patch
|
13 | 14 |
|
|
32 | 33 | """
|
33 | 34 |
|
34 | 35 |
|
| 36 | +@contextmanager |
| 37 | +def temp_environ(**env): |
| 38 | + for key, value in env.items(): |
| 39 | + if value is None: |
| 40 | + os.environ.pop(key, None) |
| 41 | + else: |
| 42 | + os.environ[key] = value |
| 43 | + try: |
| 44 | + yield |
| 45 | + finally: |
| 46 | + for key, _ in env.items(): |
| 47 | + os.environ.pop(key, None) |
| 48 | + |
| 49 | + |
35 | 50 | class SampleExporter(TemplateExporter):
|
36 | 51 | """
|
37 | 52 | Exports a Python code file.
|
@@ -268,6 +283,24 @@ def test_absolute_template_dir(self):
|
268 | 283 | assert exporter.template_name == template
|
269 | 284 | assert os.path.join(td, template) in exporter.template_paths
|
270 | 285 |
|
| 286 | + def test_env_var_template_dir(self): |
| 287 | + with TemporaryDirectory() as td: |
| 288 | + template = "env-var" |
| 289 | + template_file = os.path.join(td, template, "index.py.j2") |
| 290 | + template_dir = os.path.dirname(template_file) |
| 291 | + os.mkdir(template_dir) |
| 292 | + test_output = "env-var!" |
| 293 | + with open(template_file, "w") as f: |
| 294 | + f.write(test_output) |
| 295 | + with temp_environ(NBCONVERT_EXTRA_TEMPLATE_BASEDIRS=template_dir): |
| 296 | + config = Config() |
| 297 | + config.TemplateExporter.template_name = template |
| 298 | + config.TemplateExporter.extra_template_basedirs = [td, template_dir] |
| 299 | + exporter = self._make_exporter(config=config) |
| 300 | + assert exporter.template.filename == template_file |
| 301 | + assert exporter.template_name == template |
| 302 | + assert template_dir in exporter.template_paths |
| 303 | + |
271 | 304 | def test_local_template_dir(self):
|
272 | 305 | with TemporaryDirectory() as td, _contextlib_chdir.chdir(td): # noqa
|
273 | 306 | with patch("os.getcwd", return_value=os.path.abspath(td)):
|
|
0 commit comments