diff --git a/altair_saver/_core.py b/altair_saver/_core.py index bf2b9a9..17704fb 100644 --- a/altair_saver/_core.py +++ b/altair_saver/_core.py @@ -1,4 +1,5 @@ from collections import OrderedDict +from pathlib import Path from typing import Any, Dict, IO, Iterable, Optional, Set, Type, Union import warnings @@ -73,7 +74,7 @@ def _select_saver( def save( chart: Union[alt.TopLevelMixin, JSONDict], - fp: Optional[Union[IO, str]] = None, + fp: Optional[Union[IO, str, Path]] = None, fmt: Optional[str] = None, mode: Optional[str] = None, embed_options: Optional[JSONDict] = None, diff --git a/altair_saver/_utils.py b/altair_saver/_utils.py index be5bb37..a4f7684 100644 --- a/altair_saver/_utils.py +++ b/altair_saver/_utils.py @@ -1,12 +1,13 @@ import contextlib -from http import client import io import os import socket import subprocess import sys import tempfile -from typing import Callable, IO, Iterator, List, Optional, Union +from http import client +from pathlib import Path +from typing import IO, Callable, Iterator, List, Optional, Union import altair as alt @@ -135,7 +136,7 @@ def temporary_filename( @contextlib.contextmanager def maybe_open(fp: Union[IO, str], mode: str = "w") -> Iterator[IO]: """Context manager to write to a file specified by filename or file-like object""" - if isinstance(fp, str): + if isinstance(fp, (str, Path)): with open(fp, mode) as f: yield f elif isinstance(fp, io.TextIOBase) and "b" in mode: diff --git a/altair_saver/savers/_saver.py b/altair_saver/savers/_saver.py index 4d9622c..b23c67e 100644 --- a/altair_saver/savers/_saver.py +++ b/altair_saver/savers/_saver.py @@ -1,16 +1,17 @@ import abc import json -from typing import Any, Dict, IO, Iterable, List, Optional, Union +from pathlib import Path +from typing import IO, Any, Dict, Iterable, List, Optional, Union import altair as alt -from altair_saver.types import Mimebundle, MimebundleContent, JSONDict from altair_saver._utils import ( extract_format, fmt_to_mimetype, infer_mode_from_spec, maybe_open, ) +from altair_saver.types import JSONDict, Mimebundle, MimebundleContent class Saver(metaclass=abc.ABCMeta): @@ -91,7 +92,7 @@ def mimebundle(self, fmts: Union[str, Iterable[str]]) -> Mimebundle: return bundle def save( - self, fp: Optional[Union[IO, str]] = None, fmt: Optional[str] = None + self, fp: Optional[Union[IO, str, Path]] = None, fmt: Optional[str] = None ) -> Optional[Union[str, bytes]]: """Save a chart to file