diff --git a/src/subtitle_utils/__init__.py b/src/subtitle_utils/__init__.py index 5c54462..b48a301 100644 --- a/src/subtitle_utils/__init__.py +++ b/src/subtitle_utils/__init__.py @@ -4,6 +4,7 @@ # @Software: PyCharm # @Github :sudoskys import json +from io import TextIOBase from typing import Union, IO, Callable, Any from .convert.ass import AssConvert @@ -12,21 +13,21 @@ FOOTNOTE = None -def srt2bcc(content: Union[str, IO] +def srt2bcc(content: Union[str, IO, TextIOBase] ) -> str: result = BccConvert().srt2bcc(content=content, about=FOOTNOTE) result = json.dumps(result, ensure_ascii=False, indent=None) return result -def vtt2bcc(content: Union[str, IO] +def vtt2bcc(content: Union[str, IO, TextIOBase] ) -> str: result = BccConvert().vtt2bcc(content=content, about=FOOTNOTE) result = json.dumps(result, ensure_ascii=False, indent=None) return result -def ass2bcc(content: Union[str, IO] +def ass2bcc(content: Union[str, IO, TextIOBase] ) -> str: ass_result = AssConvert().ass2srt(content=content) result = BccConvert().srt2bcc(content=ass_result, about=FOOTNOTE) @@ -34,7 +35,7 @@ def ass2bcc(content: Union[str, IO] return result -def ass2srt(content: Union[str, IO] +def ass2srt(content: Union[str, IO, TextIOBase] ) -> str: """ :param content: @@ -44,7 +45,7 @@ def ass2srt(content: Union[str, IO] return result -def srt2ass(content: Union[str, IO], +def srt2ass(content: Union[str, IO, TextIOBase], *, header: str = None ) -> str: @@ -57,13 +58,13 @@ def srt2ass(content: Union[str, IO], return result -def bcc2srt(content: Union[str, IO], +def bcc2srt(content: Union[str, IO, TextIOBase], ) -> str: result = BccConvert().bcc2srt(content=content) return result -def bcc2ass(content: Union[str, IO] +def bcc2ass(content: Union[str, IO, TextIOBase] ) -> str: bcc_result = BccConvert().bcc2srt(content=content) result = AssConvert().srt2ass(content=bcc_result) diff --git a/src/subtitle_utils/convert/ass.py b/src/subtitle_utils/convert/ass.py index 49c0600..2acf10d 100644 --- a/src/subtitle_utils/convert/ass.py +++ b/src/subtitle_utils/convert/ass.py @@ -4,6 +4,7 @@ # @File : ass.py # @Software: PyCharm import tempfile +from io import TextIOBase from typing import Union, IO from pyasstosrt import Subtitle @@ -22,7 +23,7 @@ class AssConvert(Convert): @staticmethod - def srt2ass(content: Union[str, IO], + def srt2ass(content: Union[str, IO, TextIOBase], *, header: str = None) -> str: """ @@ -31,7 +32,7 @@ def srt2ass(content: Union[str, IO], :param header: ASS HEADER (Style) :return: processed subtitle """ - assert isinstance(content, (str, IO)), "content must be str or IO" + assert isinstance(content, (str, IO, TextIOBase)), f"content must be str or IO but {type(content)}" subs = SrtParse().parse(content=content) timestamps = [[str(sub.start), str(sub.end)] for sub in subs] subtitles = [sub.text for sub in subs] @@ -68,18 +69,21 @@ def srt2ass(content: Union[str, IO], return content @staticmethod - def ass2srt(content: Union[str, IO]) -> str: + def ass2srt(content: Union[str, IO, TextIOBase]) -> str: assert isinstance(content, (str, IO)), "content must be str or IO" # write to temp file - with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=True) as f: if isinstance(content, str): f.write(content) else: f.write(content.read()) - f.close() - sub = Subtitle(filepath=f.name) - dialog = sub.export(output_dialogues=True) + f.seek(0) _result = [] - for dialogue in dialog: - _result.append(str(dialogue)) + try: + sub = Subtitle(filepath=f.name) + dialog = sub.export(output_dialogues=True) + for dialogue in dialog: + _result.append(str(dialogue)) + finally: + f.close() return "".join(_result) diff --git a/src/subtitle_utils/convert/bcc.py b/src/subtitle_utils/convert/bcc.py index f2a59c9..e764d80 100644 --- a/src/subtitle_utils/convert/bcc.py +++ b/src/subtitle_utils/convert/bcc.py @@ -4,6 +4,7 @@ import re from datetime import datetime +from io import TextIOBase from typing import Union, IO from loguru import logger @@ -154,7 +155,7 @@ def _process_body(self, subs, about: str = None): def _time2str(time: float): return datetime.utcfromtimestamp(time).strftime("%H:%M:%S,%f")[:-3] - def srt2bcc(self, content: Union[str, IO], about: str = None): + def srt2bcc(self, content: Union[str, IO, TextIOBase], about: str = None): """ srt2bcc 将 srt 转换为 bcc B站字幕格式 :param content: srt format @@ -172,7 +173,7 @@ def srt2bcc(self, content: Union[str, IO], about: str = None): } return bcc if subs else {} - def bcc2srt(self, content: Union[str, IO]): + def bcc2srt(self, content: Union[str, IO, TextIOBase]): """ bcc2srt 将 bcc 转换为 srt 字幕格式 :param content: bcc format @@ -191,7 +192,7 @@ def bcc2srt(self, content: Union[str, IO]): srt += f"{content_str}\n\n" return srt[:-1] if subs else "" - def vtt2bcc(self, content: Union[str, IO], threshold=0.1, word=True, about: str = None): + def vtt2bcc(self, content: Union[str, IO, TextIOBase], threshold=0.1, word=True, about: str = None): """ vtt2bcc 将 vtt 转换为 bcc B站字幕格式 :param content: vtt format diff --git a/src/subtitle_utils/parse.py b/src/subtitle_utils/parse.py index f15bdb4..5255ad5 100644 --- a/src/subtitle_utils/parse.py +++ b/src/subtitle_utils/parse.py @@ -7,6 +7,7 @@ import os import tempfile from abc import ABC +from io import TextIOBase from typing import Union, IO import pysrt @@ -20,20 +21,23 @@ class Parser(ABC): Base Parser """ - def parse(self, content: Union[str, IO]): + def parse(self, content: Union[str, IO, TextIOBase]): raise NotImplementedError class SrtParse(Parser): - def parse(self, content: Union[str, IO]) -> SubRipFile: + def parse(self, content: Union[str, IO, TextIOBase]) -> SubRipFile: if isinstance(content, str): return pysrt.from_string(content) # write to temp file with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=True) as f: f.write(content.read()) - f.close() - return pysrt.open(f.name) + f.seek(0) + try: + return pysrt.open(path=f.name) + finally: + f.close() class BccParser(Parser): @@ -56,7 +60,7 @@ def parse_str(self, sentence): strs = sentence if sentence else "" return self._parse(strs) - def parse(self, content: Union[str, IO]) -> dict: + def parse(self, content: Union[str, IO, TextIOBase]) -> dict: """ Parse bcc :param content: str or IO @@ -69,7 +73,7 @@ def parse(self, content: Union[str, IO]) -> dict: class VttParser(Parser): - def parse(self, content: Union[str, IO]) -> WebVTTFile: + def parse(self, content: Union[str, IO, TextIOBase]) -> WebVTTFile: """ :param content: str or IO :return: pyvtt.WebVTTFile @@ -79,5 +83,8 @@ def parse(self, content: Union[str, IO]) -> WebVTTFile: # write to temp file with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=True) as f: f.write(content.read()) - f.close() - return pyvtt.open(f.name) + f.seek(0) + try: + return pyvtt.open(f.name) + finally: + f.close() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_function.py b/tests/test_function.py index ac97dc5..85ad96c 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -13,9 +13,11 @@ def test_show_available(): assert isinstance(show_available()[0], str), "Error Checking show list" -def test_srt2ass(): +def test_srt2bcc(): with open("test.bcc", "r") as f: bcc_exp = f.read() with open("test.srt", 'r') as file_io: - test_result = get_method(method="srt2ass")(content=file_io) + test_result = get_method(method="srt2bcc")(content=file_io) + test_result = test_result.replace("\n", "") + bcc_exp = bcc_exp.replace("\n", "") assert test_result == bcc_exp, f"Error Checking srt2ass \n{test_result}"