From 4174e394f289649179b8849b5999fd71a694e815 Mon Sep 17 00:00:00 2001 From: Nate Kean <14845347+garlic-os@users.noreply.github.com> Date: Fri, 16 May 2025 19:15:13 -0500 Subject: [PATCH] Add type annotations --- markovify/chain.py | 106 ++++++++++++++++++++++++------- markovify/splitters.py | 7 ++- markovify/text.py | 140 +++++++++++++++++++++++++++++++---------- markovify/utils.py | 66 ++++++++++++++++--- requirements-dev.txt | 1 + test/test_basic.py | 54 ++++++++-------- test/test_combine.py | 28 ++++----- test/test_itertext.py | 8 +-- 8 files changed, 302 insertions(+), 108 deletions(-) diff --git a/markovify/chain.py b/markovify/chain.py index b19e8e2..d7a45da 100644 --- a/markovify/chain.py +++ b/markovify/chain.py @@ -3,12 +3,50 @@ import bisect import json import copy +from typing import ( + Callable, + Dict, + Generator, + Iterable, + List, + Tuple, + Type, + TypeVar, + Union, + cast, +) + BEGIN = "___BEGIN__" END = "___END__" -def accumulate(iterable, func=operator.add): +State = Tuple[str, ...] +NextDict = Dict[str, int] +NextCompiled = Tuple[List[str], List[int]] +ModelUncompiled = Dict[State, NextDict] +ModelCompiled = Dict[State, NextCompiled] +Model = Union[ModelUncompiled, ModelCompiled] + + +T = TypeVar("T") +ChainT = TypeVar("ChainT", bound="Chain") +AccT = TypeVar("AccT") +AccF = Callable[[AccT, AccT], AccT] + + +class ParamError(Exception): + pass + + +def cast_not_none(var: Union[T, None]) -> T: + return cast(T, var) + + +def accumulate( + iterable: Iterable[AccT], + func: AccF[AccT] = operator.add, +) -> Generator[AccT, None, None]: """ Cumulative calculations. (Summation, by default.) 
Via: https://docs.python.org/3/library/itertools.html#itertools.accumulate @@ -21,10 +59,10 @@ def accumulate(iterable, func=operator.add): yield total -def compile_next(next_dict): +def compile_next(next_dict: NextDict) -> NextCompiled: words = list(next_dict.keys()) cff = list(accumulate(next_dict.values())) - return [words, cff] + return words, cff class Chain: @@ -33,7 +71,12 @@ class Chain: For example: Sentences. """ - def __init__(self, corpus, state_size, model=None): + def __init__( + self, + corpus: Union[Iterable[List[str]], None], + state_size: int, + model: Union[ModelCompiled, None] = None, + ): """ `corpus`: A list of lists, where each outer list is a "run" of the process (e.g., a single sentence), and each inner list @@ -44,21 +87,25 @@ def __init__(self, corpus, state_size, model=None): `state_size`: An integer indicating the number of items the model uses to represent its state. For text generation, 2 or 3 are typical. """ + if corpus is None and model is None: + raise ParamError("Must provide either `corpus` or `model`.") self.state_size = state_size - self.model = model or self.build(corpus, self.state_size) + self.model = model or self.build(cast_not_none(corpus), self.state_size) self.compiled = (len(self.model) > 0) and ( - type(self.model[tuple([BEGIN] * state_size)]) == list + isinstance(self.model[tuple([BEGIN] * state_size)], (tuple, list)) ) if not self.compiled: self.precompute_begin_state() - def compile(self, inplace=False): + def compile(self, inplace: bool = False) -> "Chain": if self.compiled: if inplace: return self - return Chain(None, self.state_size, model=copy.deepcopy(self.model)) - mdict = { - state: compile_next(next_dict) for (state, next_dict) in self.model.items() + model = cast(ModelCompiled, self.model) + return Chain(None, self.state_size, model=copy.deepcopy(model)) + model = cast(ModelUncompiled, self.model) + mdict: ModelCompiled = { + state: compile_next(next_dict) for (state, next_dict) in model.items() } if not 
inplace: return Chain(None, self.state_size, model=mdict) @@ -66,7 +113,7 @@ def compile(self, inplace=False): self.compiled = True return self - def build(self, corpus, state_size): + def build(self, corpus: Iterable[List[str]], state_size: int) -> Model: """ Build a Python representation of the Markov model. Returns a dict of dicts where the keys of the outer dict represent all possible states, @@ -77,7 +124,7 @@ def build(self, corpus, state_size): # Using a DefaultDict here would be a lot more convenient, however the memory # usage is far higher. - model = {} + model: Model = {} for run in corpus: items = ([BEGIN] * state_size) + run + [END] @@ -93,33 +140,40 @@ def build(self, corpus, state_size): model[state][follow] += 1 return model - def precompute_begin_state(self): + def precompute_begin_state(self) -> None: """ Caches the summation calculation and available choices for BEGIN * state_size. Significantly speeds up chain generation on large corpora. Thanks, @schollz! """ + model = cast(ModelUncompiled, self.model) begin_state = tuple([BEGIN] * self.state_size) - choices, cumdist = compile_next(self.model[begin_state]) + choices, cumdist = compile_next(model[begin_state]) self.begin_cumdist = cumdist self.begin_choices = choices - def move(self, state): + def move(self, state: State) -> str: """ Given a state, choose the next item at random. 
""" if self.compiled: - choices, cumdist = self.model[state] + model = cast(ModelCompiled, self.model) + choices, cumdist = model[state] elif state == tuple([BEGIN] * self.state_size): choices = self.begin_choices cumdist = self.begin_cumdist else: - choices, weights = zip(*self.model[state].items()) + model = cast(ModelUncompiled, self.model) + choices = tuple(model[state].keys()) + weights = tuple(model[state].values()) cumdist = list(accumulate(weights)) r = random.random() * cumdist[-1] selection = choices[bisect.bisect(cumdist, r)] return selection - def gen(self, init_state=None): + def gen( + self, + init_state: Union[State, None] = None, + ) -> Generator[str, None, None]: """ Starting either with a naive BEGIN state, or the provided `init_state` (as a tuple), return a generator that will yield successive items @@ -133,7 +187,7 @@ def gen(self, init_state=None): yield next_word state = tuple(state[1:]) + (next_word,) - def walk(self, init_state=None): + def walk(self, init_state: Union[State, None] = None) -> List[str]: """ Return a list representing a single run of the Markov model, either starting with a naive BEGIN state, or the provided `init_state` @@ -141,14 +195,17 @@ def walk(self, init_state=None): """ return list(self.gen(init_state)) - def to_json(self): + def to_json(self) -> str: """ Dump the model as a JSON object, for loading later. """ return json.dumps(list(self.model.items())) @classmethod - def from_json(cls, json_thing): + def from_json( + cls: Type[ChainT], + json_thing: Union[str, Dict, List], + ) -> ChainT: """ Given a JSON object or JSON string that was created by `self.to_json`, return the corresponding markovify.Chain. 
@@ -165,5 +222,12 @@ def from_json(cls, json_thing): state_size = len(list(rehydrated.keys())[0]) + compiled = (len(rehydrated) > 0) and ( + isinstance(rehydrated[tuple([BEGIN] * state_size)], list) + ) + if compiled: + for state in rehydrated: + rehydrated[state] = tuple(rehydrated[state]) + inst = cls(None, state_size, rehydrated) return inst diff --git a/markovify/splitters.py b/markovify/splitters.py index 01e8a90..851c948 100644 --- a/markovify/splitters.py +++ b/markovify/splitters.py @@ -1,4 +1,5 @@ import re +from typing import List uppercase_letter_pat = re.compile(r"^[A-Z]$", re.UNICODE) initialism_pat = re.compile(r"^[A-Za-z0-9]{1,2}(\.[A-Za-z0-9]{1,2})+\.$", re.UNICODE) @@ -22,7 +23,7 @@ abbr_lowercase = "etc|v|vs|viz|al|pct".split("|") -def is_abbreviation(dotted_word): +def is_abbreviation(dotted_word: str) -> bool: clipped = dotted_word[:-1] if re.match(uppercase_letter_pat, clipped[0]): if len(clipped) == 1: # Initial @@ -38,7 +39,7 @@ def is_abbreviation(dotted_word): return False -def is_sentence_ender(word): +def is_sentence_ender(word: str) -> bool: if re.match(initialism_pat, word) is not None: return False if word[-1] in ["?", "!"]: @@ -50,7 +51,7 @@ def is_sentence_ender(word): return False -def split_into_sentences(text): +def split_into_sentences(text: str) -> List[str]: potential_end_pat = re.compile( r"".join( [ diff --git a/markovify/text.py b/markovify/text.py index d608393..b382292 100644 --- a/markovify/text.py +++ b/markovify/text.py @@ -2,15 +2,56 @@ import re import json import random +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + Iterator, + List, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +if TYPE_CHECKING: + from typing_extensions import TypedDict, Unpack + from .splitters import split_into_sentences -from .chain import Chain, BEGIN +from .chain import Chain, BEGIN, State from unidecode import unidecode + DEFAULT_MAX_OVERLAP_RATIO = 0.7 DEFAULT_MAX_OVERLAP_TOTAL = 15 DEFAULT_TRIES = 10 +T = TypeVar("T") 
+TextT = TypeVar("TextT", bound="Text") + + +def cast_not_none(var: Union[T, None]) -> T: + return cast(T, var) + + +if TYPE_CHECKING: + + class TextMarkovifyDict(TypedDict): + state_size: int + chain: str + parsed_sentences: Union[List[List[str]], None] + + class MakeSentenceKwargs(TypedDict, total=False): + tries: int + max_overlap_ratio: float + max_overlap_total: int + test_output: bool + max_words: Union[int, None] + min_words: Union[int, None] + + class ParamError(Exception): pass @@ -20,13 +61,13 @@ class Text: def __init__( self, - input_text, - state_size=2, - chain=None, - parsed_sentences=None, - retain_original=True, - well_formed=True, - reject_reg="", + input_text: Union[str, None], + state_size: int = 2, + chain: Union[Chain, None] = None, + parsed_sentences: Union[List[List[str]], None] = None, + retain_original: bool = True, + well_formed: bool = True, + reject_reg: "Union[str, re.Pattern]" = "", ): """ input_text: A string. @@ -55,7 +96,7 @@ def __init__( if self.retain_original: self.parsed_sentences = parsed_sentences or list( - self.generate_corpus(input_text) + self.generate_corpus(cast_not_none(input_text)) ) # Rejoined text lets us assess the novelty of generated sentences @@ -64,11 +105,20 @@ def __init__( ) self.chain = chain or Chain(self.parsed_sentences, state_size) else: - if not chain: - parsed = parsed_sentences or self.generate_corpus(input_text) - self.chain = chain or Chain(parsed, state_size) + if chain is None: + if not can_make_sentences: + raise ParamError( + "Must provide either `input_text`, `parsed_sentences`, " + "or `chain`." 
+ ) + parsed = parsed_sentences or self.generate_corpus( + cast_not_none(input_text) + ) + self.chain = Chain(parsed, state_size) + else: + self.chain = chain - def compile(self, inplace=False): + def compile(self, inplace: bool = False) -> "Text": if inplace: self.chain.compile(inplace=True) return self @@ -86,7 +136,7 @@ def compile(self, inplace=False): reject_reg=self.reject_pat, ) - def to_dict(self): + def to_dict(self) -> "TextMarkovifyDict": """ Returns the underlying data as a Python dict. """ @@ -96,14 +146,14 @@ def to_dict(self): "parsed_sentences": self.parsed_sentences if self.retain_original else None, } - def to_json(self): + def to_json(self) -> str: """ Returns the underlying data as a JSON string. """ return json.dumps(self.to_dict()) @classmethod - def from_dict(cls, obj, **kwargs): + def from_dict(cls: Type[TextT], obj: "TextMarkovifyDict", **kwargs) -> TextT: return cls( None, state_size=obj["state_size"], @@ -112,16 +162,16 @@ def from_dict(cls, obj, **kwargs): ) @classmethod - def from_json(cls, json_str): + def from_json(cls: Type[TextT], json_str: str) -> TextT: return cls.from_dict(json.loads(json_str)) - def sentence_split(self, text): + def sentence_split(self, text: str) -> List[str]: """ Splits full-text string into a list of sentences. """ return split_into_sentences(text) - def sentence_join(self, sentences): + def sentence_join(self, sentences: Iterable[str]) -> str: """ Re-joins a list of sentences into the full text. """ @@ -129,19 +179,19 @@ def sentence_join(self, sentences): word_split_pattern = re.compile(r"\s+") - def word_split(self, sentence): + def word_split(self, sentence: str) -> List[str]: """ Splits a sentence into a list of words. """ return re.split(self.word_split_pattern, sentence) - def word_join(self, words): + def word_join(self, words: Iterable[str]) -> str: """ Re-joins a list of words into a sentence. 
""" return " ".join(words) - def test_sentence_input(self, sentence): + def test_sentence_input(self, sentence: str) -> bool: """ A basic sentence filter. The default rejects sentences that contain the type of punctuation that would look strange on its own @@ -156,7 +206,7 @@ def test_sentence_input(self, sentence): return False return True - def generate_corpus(self, text): + def generate_corpus(self, text: str) -> Iterator[List[str]]: """ Given a text string, returns a list of lists; that is, a list of "sentences," each of which is a list of words. Before splitting into @@ -172,7 +222,12 @@ def generate_corpus(self, text): runs = map(self.word_split, passing) return runs - def test_sentence_output(self, words, max_overlap_ratio, max_overlap_total): + def test_sentence_output( + self, + words: List[str], + max_overlap_ratio: float, + max_overlap_total: int, + ) -> bool: """ Given a generated list of words, accept or reject it. This one rejects sentences that too closely match the original text, namely those that @@ -192,7 +247,11 @@ def test_sentence_output(self, words, max_overlap_ratio, max_overlap_total): return False return True - def make_sentence(self, init_state=None, **kwargs): + def make_sentence( + self, + init_state: Union[Tuple[str, ...], None] = None, + **kwargs: "Unpack[MakeSentenceKwargs]", + ) -> Union[str, None]: """ Attempts `tries` (default: 10) times to generate a valid sentence, based on the model and `test_sentence_output`. 
Passes `max_overlap_ratio` @@ -214,8 +273,8 @@ def make_sentence(self, init_state=None, **kwargs): mor = kwargs.get("max_overlap_ratio", DEFAULT_MAX_OVERLAP_RATIO) mot = kwargs.get("max_overlap_total", DEFAULT_MAX_OVERLAP_TOTAL) test_output = kwargs.get("test_output", True) - max_words = kwargs.get("max_words", None) - min_words = kwargs.get("min_words", None) + max_words = kwargs.get("max_words") + min_words = kwargs.get("min_words") if init_state is None: prefix = [] @@ -240,7 +299,12 @@ def make_sentence(self, init_state=None, **kwargs): return self.word_join(words) return None - def make_short_sentence(self, max_chars, min_chars=0, **kwargs): + def make_short_sentence( + self, + max_chars: int, + min_chars: int = 0, + **kwargs: "Unpack[MakeSentenceKwargs]", + ) -> Union[str, None]: """ Tries making a sentence of no more than `max_chars` characters and optionally no less than `min_chars` characters, passing **kwargs to `self.make_sentence`. @@ -252,7 +316,12 @@ def make_short_sentence(self, max_chars, min_chars=0, **kwargs): if sentence and min_chars <= len(sentence) <= max_chars: return sentence - def make_sentence_with_start(self, beginning, strict=True, **kwargs): + def make_sentence_with_start( + self, + beginning: str, + strict: bool = True, + **kwargs: "Unpack[MakeSentenceKwargs]", + ) -> str: """ Tries making a sentence that begins with `beginning` string, which should be a string of one to `self.state` words known @@ -298,7 +367,7 @@ def make_sentence_with_start(self, beginning, strict=True, **kwargs): raise ParamError(err_msg) @functools.lru_cache(maxsize=1) - def find_init_states_from_chain(self, split): + def find_init_states_from_chain(self, split: State) -> List[State]: """ Find all chains that begin with the split when `self.make_sentence_with_start` is called with strict == False. 
@@ -316,14 +385,19 @@ def find_init_states_from_chain(self, split): ] @classmethod - def from_chain(cls, chain_json, corpus=None, parsed_sentences=None): + def from_chain( + cls: Type[TextT], + chain_json: Union[str, Dict, List], + corpus: Union[str, None] = None, + parsed_sentences: Union[List[List[str]], None] = None, + ) -> TextT: """ Init a Text class based on an existing chain JSON string or object If corpus is None, overlap checking won't work. """ chain = Chain.from_json(chain_json) return cls( - corpus or None, + corpus, parsed_sentences=parsed_sentences, state_size=chain.state_size, chain=chain, @@ -336,5 +410,5 @@ class NewlineText(Text): text where the sentences are separated by newlines instead of ". " """ - def sentence_split(self, text): + def sentence_split(self, text: str) -> List[str]: return re.split(r"\s*\n\s*", text) diff --git a/markovify/utils.py b/markovify/utils.py index 91bc99f..cd7f621 100644 --- a/markovify/utils.py +++ b/markovify/utils.py @@ -1,16 +1,26 @@ -from .chain import Chain +from typing import TYPE_CHECKING, List, Sequence, TypeVar, Union, cast, overload +from .chain import Chain, ModelUncompiled from .text import Text +if TYPE_CHECKING: + from typing_extensions import assert_never +else: -def get_model_dict(thing): + def assert_never(arg: object) -> None: + pass + + +def get_model_dict( + thing: Union[Chain, Text, List, ModelUncompiled], +) -> ModelUncompiled: if isinstance(thing, Chain): if thing.compiled: raise ValueError("Not implemented for compiled markovify.Chain") - return thing.model + return cast(ModelUncompiled, thing.model) if isinstance(thing, Text): if thing.chain.compiled: raise ValueError("Not implemented for compiled markovify.Chain") - return thing.chain.model + return cast(ModelUncompiled, thing.chain.model) if isinstance(thing, list): return dict(thing) if isinstance(thing, dict): @@ -21,7 +31,45 @@ def get_model_dict(thing): ) -def combine(models, weights=None): +T = TypeVar("T", Chain, Text, List, ModelUncompiled) + + 
+@overload +def combine( + models: Sequence[Chain], + weights: Union[List[float], None] = None, +) -> Chain: + ... + + +@overload +def combine( + models: Sequence[Text], + weights: Union[List[float], None] = None, +) -> Text: + ... + + +@overload +def combine( + models: Sequence[List], + weights: Union[List[float], None] = None, +) -> List: + ... + + +@overload +def combine( + models: Sequence[ModelUncompiled], + weights: Union[List[float], None] = None, +) -> ModelUncompiled: + ... + + +def combine( + models: Sequence[T], + weights: Union[List[float], None] = None, +) -> T: if weights is None: weights = [1 for _ in range(len(models))] @@ -37,7 +85,7 @@ def combine(models, weights=None): if len(set(map(type, models))) != 1: raise ValueError("All `models` must be of the same type.") - c = {} + c: ModelUncompiled = {} for m, w in zip(model_dicts, weights): for state, options in m.items(): @@ -53,9 +101,10 @@ def combine(models, weights=None): return Chain.from_json(c) if isinstance(ret_inst, Text): ret_inst.find_init_states_from_chain.cache_clear() - if any(m.retain_original for m in models): + text_models = cast(List[Text], models) + if any(m.retain_original for m in text_models): combined_sentences = [] - for m in models: + for m in text_models: if m.retain_original: combined_sentences += m.parsed_sentences return ret_inst.from_chain(c, parsed_sentences=combined_sentences) @@ -65,3 +114,4 @@ def combine(models, weights=None): return list(c.items()) if isinstance(ret_inst, dict): return c + assert_never(ret_inst) diff --git a/requirements-dev.txt b/requirements-dev.txt index 95c23a4..1c5a6cb 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,4 @@ flake8 pytest pytest-cov coveralls +typing_extensions diff --git a/test/test_basic.py b/test/test_basic.py index 506a483..8a0d89b 100644 --- a/test/test_basic.py +++ b/test/test_basic.py @@ -11,24 +11,24 @@ def get_sorted(chain_json): class MarkovifyTestBase(unittest.TestCase): __test__ = False - def 
test_text_too_small(self): + def test_text_too_small(self) -> None: text = "Example phrase. This is another example sentence." text_model = markovify.Text(text) assert text_model.make_sentence() is None - def test_sherlock(self): + def test_sherlock(self) -> None: text_model = self.sherlock_model sent = text_model.make_sentence() assert len(sent) != 0 - def test_json(self): + def test_json(self) -> None: text_model = self.sherlock_model json_model = text_model.to_json() new_text_model = markovify.Text.from_json(json_model) sent = new_text_model.make_sentence() assert len(sent) != 0 - def test_chain(self): + def test_chain(self) -> None: text_model = self.sherlock_model chain_json = text_model.chain.to_json() @@ -41,34 +41,36 @@ def test_chain(self): sent = new_text_model.make_sentence() assert len(sent) != 0 - def test_make_sentence_with_start(self): + def test_make_sentence_with_start(self) -> None: text_model = self.sherlock_model start_str = "Sherlock Holmes" sent = text_model.make_sentence_with_start(start_str) assert sent is not None assert start_str == sent[: len(start_str)] - def test_make_sentence_with_start_one_word(self): + def test_make_sentence_with_start_one_word(self) -> None: text_model = self.sherlock_model start_str = "Sherlock" sent = text_model.make_sentence_with_start(start_str) assert sent is not None assert start_str == sent[: len(start_str)] - def test_make_sentence_with_start_one_word_that_doesnt_begin_a_sentence(self): + def test_make_sentence_with_start_one_word_that_doesnt_begin_a_sentence( + self, + ) -> None: text_model = self.sherlock_model start_str = "dog" with self.assertRaises(KeyError): text_model.make_sentence_with_start(start_str) - def test_make_sentence_with_word_not_at_start_of_sentence(self): + def test_make_sentence_with_word_not_at_start_of_sentence(self) -> None: text_model = self.sherlock_model start_str = "dog" sent = text_model.make_sentence_with_start(start_str, strict=False) assert sent is not None assert start_str 
== sent[: len(start_str)] - def test_make_sentence_with_words_not_at_start_of_sentence(self): + def test_make_sentence_with_words_not_at_start_of_sentence(self) -> None: text_model = self.sherlock_model_ss3 # " I was " has 128 matches in sherlock.txt # " was I " has 2 matches in sherlock.txt @@ -77,26 +79,28 @@ def test_make_sentence_with_words_not_at_start_of_sentence(self): assert sent is not None assert start_str == sent[: len(start_str)] - def test_make_sentence_with_words_not_at_start_of_sentence_miss(self): + def test_make_sentence_with_words_not_at_start_of_sentence_miss(self) -> None: text_model = self.sherlock_model_ss3 start_str = "was werewolf" with self.assertRaises(markovify.text.ParamError): text_model.make_sentence_with_start(start_str, strict=False, tries=50) - def test_make_sentence_with_words_not_at_start_of_sentence_of_state_size(self): + def test_make_sentence_with_words_not_at_start_of_sentence_of_state_size( + self, + ) -> None: text_model = self.sherlock_model_ss2 start_str = "was I" sent = text_model.make_sentence_with_start(start_str, strict=False, tries=50) assert sent is not None assert start_str == sent[: len(start_str)] - def test_make_sentence_with_words_to_many(self): + def test_make_sentence_with_words_to_many(self) -> None: text_model = self.sherlock_model start_str = "dog is good" with self.assertRaises(markovify.text.ParamError): text_model.make_sentence_with_start(start_str, strict=False) - def test_make_sentence_with_start_three_words(self): + def test_make_sentence_with_start_three_words(self) -> None: start_str = "Sherlock Holmes was" text_model = self.sherlock_model try: @@ -111,36 +115,36 @@ def test_make_sentence_with_start_three_words(self): sent = text_model.make_sentence_with_start("Sherlock", tries=50) assert markovify.chain.BEGIN not in sent - def test_short_sentence(self): + def test_short_sentence(self) -> None: text_model = self.sherlock_model sent = None while sent is None: sent = text_model.make_short_sentence(45) 
assert len(sent) <= 45 - def test_short_sentence_min_chars(self): + def test_short_sentence_min_chars(self) -> None: sent = None while sent is None: sent = self.sherlock_model.make_short_sentence(100, min_chars=50) assert len(sent) <= 100 assert len(sent) >= 50 - def test_dont_test_output(self): + def test_dont_test_output(self) -> None: text_model = self.sherlock_model sent = text_model.make_sentence(test_output=False) assert sent is not None - def test_max_words(self): + def test_max_words(self) -> None: text_model = self.sherlock_model sent = text_model.make_sentence(max_words=0) assert sent is None - def test_min_words(self): + def test_min_words(self) -> None: text_model = self.sherlock_model sent = text_model.make_sentence(min_words=5) assert len(sent.split(" ")) >= 5 - def test_newline_text(self): + def test_newline_text(self) -> None: with open( os.path.join(os.path.dirname(__file__), "texts/senate-bills.txt"), encoding="utf-8", @@ -148,15 +152,15 @@ def test_newline_text(self): model = markovify.NewlineText(f.read()) model.make_sentence() - def test_bad_corpus(self): + def test_bad_corpus(self) -> None: with self.assertRaises(Exception): - markovify.Chain(corpus="testing, testing", state_size=2) + markovify.Chain(corpus="testing, testing", state_size=2) # type: ignore - def test_bad_json(self): + def test_bad_json(self) -> None: with self.assertRaises(Exception): - markovify.Chain.from_json(1) + markovify.Chain.from_json(1) # type: ignore - def test_custom_regex(self): + def test_custom_regex(self) -> None: with self.assertRaises(Exception): markovify.NewlineText( "This sentence contains a custom bad character: #.", reject_reg=r"#" @@ -187,7 +191,7 @@ class MarkovifyTestCompiled(MarkovifyTestBase): sherlock_model_ss2 = (markovify.Text(sherlock_text, state_size=2)).compile() sherlock_model_ss3 = (markovify.Text(sherlock_text, state_size=3)).compile() - def test_recompiling(self): + def test_recompiling(self) -> None: model_recompile = 
self.sherlock_model.compile() sent = model_recompile.make_sentence() assert len(sent) != 0 diff --git a/test/test_combine.py b/test/test_combine.py index 21b1e05..a20a5d0 100644 --- a/test/test_combine.py +++ b/test/test_combine.py @@ -16,74 +16,74 @@ def get_sorted(chain_json): class MarkovifyTest(unittest.TestCase): - def test_simple(self): + def test_simple(self) -> None: text_model = sherlock_model combo = markovify.combine([text_model, text_model], [0.5, 0.5]) assert combo.chain.model == text_model.chain.model - def test_double_weighted(self): + def test_double_weighted(self) -> None: text_model = sherlock_model combo = markovify.combine([text_model, text_model]) assert combo.chain.model != text_model.chain.model - def test_combine_chains(self): + def test_combine_chains(self) -> None: chain = sherlock_model.chain markovify.combine([chain, chain]) - def test_combine_dicts(self): + def test_combine_dicts(self) -> None: _dict = sherlock_model.chain.model markovify.combine([_dict, _dict]) - def test_combine_lists(self): + def test_combine_lists(self) -> None: _list = list(sherlock_model.chain.model.items()) markovify.combine([_list, _list]) - def test_bad_types(self): + def test_bad_types(self) -> None: with self.assertRaises(Exception): markovify.combine(["testing", "testing"]) - def test_bad_weights(self): + def test_bad_weights(self) -> None: with self.assertRaises(Exception): text_model = sherlock_model markovify.combine([text_model, text_model], [0.5]) - def test_mismatched_state_sizes(self): + def test_mismatched_state_sizes(self) -> None: with self.assertRaises(Exception): text_model_a = markovify.Text(sherlock, state_size=2) text_model_b = markovify.Text(sherlock, state_size=3) markovify.combine([text_model_a, text_model_b]) - def test_mismatched_model_types(self): + def test_mismatched_model_types(self) -> None: with self.assertRaises(Exception): text_model_a = sherlock_model text_model_b = markovify.NewlineText(sherlock) markovify.combine([text_model_a, 
text_model_b]) - def test_compiled_model_fail(self): + def test_compiled_model_fail(self) -> None: with self.assertRaises(Exception): model_a = sherlock_model model_b = sherlock_model_compiled markovify.combine([model_a, model_b]) - def test_compiled_chain_fail(self): + def test_compiled_chain_fail(self) -> None: with self.assertRaises(Exception): model_a = sherlock_model.chain model_b = sherlock_model_compiled.chain markovify.combine([model_a, model_b]) - def test_combine_no_retain(self): + def test_combine_no_retain(self) -> None: text_model = sherlock_model_no_retain combo = markovify.combine([text_model, text_model]) assert not combo.retain_original - def test_combine_retain_on_no_retain(self): + def test_combine_retain_on_no_retain(self) -> None: text_model_a = sherlock_model_no_retain text_model_b = sherlock_model combo = markovify.combine([text_model_a, text_model_b]) assert combo.retain_original assert combo.parsed_sentences == text_model_b.parsed_sentences - def test_combine_no_retain_on_retain(self): + def test_combine_no_retain_on_retain(self) -> None: text_model_a = sherlock_model_no_retain text_model_b = sherlock_model combo = markovify.combine([text_model_b, text_model_a]) diff --git a/test/test_itertext.py b/test/test_itertext.py index 0ff46b9..af5ddf0 100644 --- a/test/test_itertext.py +++ b/test/test_itertext.py @@ -4,14 +4,14 @@ class MarkovifyTest(unittest.TestCase): - def test_simple(self): + def test_simple(self) -> None: with open(os.path.join(os.path.dirname(__file__), "texts/sherlock.txt")) as f: sherlock_model = markovify.Text(f) sent = sherlock_model.make_sentence() assert sent is not None assert len(sent) != 0 - def test_without_retaining(self): + def test_without_retaining(self) -> None: with open( os.path.join(os.path.dirname(__file__), "texts/senate-bills.txt"), encoding="utf-8", @@ -21,7 +21,7 @@ def test_without_retaining(self): assert sent is not None assert len(sent) != 0 - def test_from_json_without_retaining(self): + def 
test_from_json_without_retaining(self) -> None: with open( os.path.join(os.path.dirname(__file__), "texts/senate-bills.txt"), encoding="utf-8", @@ -33,7 +33,7 @@ def test_from_json_without_retaining(self): assert sent is not None assert len(sent) != 0 - def test_from_mult_files_without_retaining(self): + def test_from_mult_files_without_retaining(self) -> None: models = [] for dirpath, _, filenames in os.walk( os.path.join(os.path.dirname(__file__), "texts")