diff --git a/altair/utils/save.py b/altair/utils/save.py index 662cd001c..94ddab6f7 100644 --- a/altair/utils/save.py +++ b/altair/utils/save.py @@ -1,11 +1,13 @@ import json +import pathlib from .mimebundle import spec_to_mimebundle def write_file_or_filename(fp, content, mode="w"): - """Write content to fp, whether fp is a string or a file-like object""" - if isinstance(fp, str): + """Write content to fp, whether fp is a string, a pathlib Path or a + file-like object""" + if isinstance(fp, str) or isinstance(fp, pathlib.PurePath): with open(fp, mode) as f: f.write(content) else: @@ -34,8 +36,8 @@ def save( ---------- chart : alt.Chart the chart instance to save - fp : string filename or file-like object - file in which to write the chart. + fp : string filename, pathlib.Path or file-like object + file to which to write the chart. format : string (optional) the format to write: one of ['json', 'html', 'png', 'svg']. If not specified, the format will be determined from the filename. @@ -71,6 +73,8 @@ def save( if format is None: if isinstance(fp, str): format = fp.split(".")[-1] + elif isinstance(fp, pathlib.PurePath): + format = fp.suffix.lstrip(".") else: raise ValueError( "must specify file format: " "['png', 'svg', 'pdf', 'html', 'json']" diff --git a/altair/vegalite/v4/tests/test_api.py b/altair/vegalite/v4/tests/test_api.py index f0217888c..befb8d5ee 100644 --- a/altair/vegalite/v4/tests/test_api.py +++ b/altair/vegalite/v4/tests/test_api.py @@ -4,6 +4,7 @@ import json import operator import os +import pathlib import tempfile import jsonschema @@ -284,12 +285,14 @@ def test_save(format, basic_chart): fid, filename = tempfile.mkstemp(suffix="." + format) os.close(fid) - try: - basic_chart.save(filename) - with open(filename, mode) as f: - assert f.read()[:1000] == content[:1000] - finally: - os.remove(filename) + # test both string filenames and pathlib.Paths + for fp in [filename, pathlib.Path(filename)]: + try: + basic_chart.save(fp) + with open(fp, mode) as f: + assert f.read()[:1000] == content[:1000] + finally: + os.remove(fp) def test_facet_basic():