diff --git a/docs/source/en/internal/generation_utils.mdx b/docs/source/en/internal/generation_utils.mdx index bdb6c7c59ce3..3eb54d312c20 100644 --- a/docs/source/en/internal/generation_utils.mdx +++ b/docs/source/en/internal/generation_utils.mdx @@ -240,6 +240,8 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] DisjunctiveConstraint +[[autodoc]] ConjunctiveDisjunctiveConstraint + [[autodoc]] ConstraintListState ## BeamSearch diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 32f865a84756..d51562da41b5 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -708,6 +708,7 @@ "Constraint", "ConstraintListState", "DisjunctiveConstraint", + "ConjunctiveDisjunctiveConstraint", "PhrasalConstraint", ] _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer", "ConstrainedBeamSearchScorer"] @@ -3294,6 +3295,7 @@ TextDatasetForNextSentencePrediction, ) from .generation_beam_constraints import ( + ConjunctiveDisjunctiveConstraint, Constraint, ConstraintListState, DisjunctiveConstraint, diff --git a/src/transformers/generation_beam_constraints.py b/src/transformers/generation_beam_constraints.py index baf7e3b71e3e..3af65e14f3fb 100644 --- a/src/transformers/generation_beam_constraints.py +++ b/src/transformers/generation_beam_constraints.py @@ -1,5 +1,8 @@ +import collections +import copy +import itertools from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Iterable, List, Optional, Union class Constraint(ABC): @@ -201,61 +204,218 @@ def copy(self, stateful=False): return new_constraint -class DisjunctiveTrie: - def __init__(self, nested_token_ids: List[List[int]], no_subsets=True): - r""" - A helper class that builds a trie with the words represented in `nested_token_ids`. 
- """ - self.max_height = max([len(one) for one in nested_token_ids]) - - root = dict() - for token_ids in nested_token_ids: - level = root - for tidx, token_id in enumerate(token_ids): - if token_id not in level: - level[token_id] = dict() - - level = level[token_id] - - if no_subsets and self.has_subsets(root, nested_token_ids): - raise ValueError( - "Each list in `nested_token_ids` can't be a complete subset of another list, but is" - f" {nested_token_ids}." - ) +class DisjointSet(object): + r""" + A helper class that maintains a disjoint-set. + """ - self.trie = root + def __init__(self): + self.root = {} + self.roots = set() + self.set = {} + self.count = {} + + def __getitem__(self, node: int) -> int: + if self.root[node] != self.root[self.root[node]]: + self.root[node] = self[self.root[node]] + return self.root[node] + + def __ior__(self, nodes: Iterable[int]): + roots = {self[node] for node in nodes} + if 1 < len(roots): + new_root = roots.pop() + for root in roots: + self.root[root] = new_root + self.roots.remove(root) + self.set[new_root] |= self.set[root] + self.count[new_root] += self.count[root] + return self + + +class ACAutomaton(object): + r""" + A helper class that builds an AC Automaton with the words represented in `force_words_ids`. The Aho-Corasick + algorithm is adapted to handle both conjunctive and disjunctive cases at the same time. Disjoint Set is used to + handle the edge cases where disjunctive constraints share words. + """ - def next_tokens(self, current_seq): - """ - The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`. 
- """ - start = self.trie + def __init__( + self, force_words_ids: Iterable[Union[Iterable[int], Iterable[Iterable[int]]]], no_subsets: bool = True + ): + self.current_node = 0 + self.unfulfilled = 0 + + self.trie = [{}] + self.parent = [(0, 0)] + self.conj_count = [0] + self.height = [0] + self.disj_set = [set()] + + self.disj_constraints = [] + self.disj_max_height = [] + + for word_ids in map(iter, force_words_ids): + try: + word_id = next(word_ids) + except StopIteration: + continue + word_ids = itertools.chain((word_id,), word_ids) + if isinstance(word_id, int): + word_ids = (word_ids,) + elif not no_subsets: + word_ids, ac_automaton = [], ACAutomaton(word_ids) + + nodes = [(0,)] + while nodes: + *token_ids, root = nodes.pop() + + if ac_automaton.output[root] or ac_automaton.conj_count[root]: + word_ids.append(token_ids) + else: + for token_id, node in ac_automaton.trie[root].items(): + nodes.append((*token_ids, token_id, node)) + + nodes = set(map(self.insert_leaf, word_ids)) + nodes.discard(0) + + if len(nodes) == 1: + (node,) = nodes + self.conj_count[node] += 1 + self.unfulfilled += self.height[node] + elif nodes: + for node in nodes: + self.disj_set[node].add(len(self.disj_constraints)) + self.disj_constraints.append(nodes) + max_height = max(self.height[node] for node in nodes) + self.disj_max_height.append(max_height) + self.unfulfilled += max_height + + self.disj_group = DisjointSet() + + self.suffix = [0] * len(self.trie) + self.output = [0] * len(self.trie) + + nodes = collections.deque(self.trie[0].values()) + while nodes: + root = nodes.popleft() + + suffix = self.suffix[root] + self.output[root] = suffix if self.conj_count[suffix] or self.disj_set[suffix] else self.output[suffix] + + for token_id, node in self.trie[root].items(): + self.suffix[node] = self.does_advance(suffix, token_id) + nodes.append(node) + + def insert_leaf(self, token_ids: Iterable[int]) -> int: + node = 0 + for height, token_id in enumerate(token_ids, 1): + if token_id not in 
self.trie[node]: + self.trie[node][token_id] = len(self.trie) + self.trie.append({}) + self.parent.append((node, token_id)) + self.conj_count.append(0) + self.height.append(height) + self.disj_set.append(set()) + node = self.trie[node][token_id] + return node + + def is_invalid(self, node: int) -> bool: + return node and self.conj_count[node] + len(self.disj_set[node]) == 0 and not self.trie[node] + + def delete_from_trie(self, node: int): + while self.is_invalid(node): + node, token_id = self.parent[node] + del self.trie[node][token_id] + + def try_delete_leaf(self, node: int): + if 0 < self.conj_count[node] + len(self.disj_set[node]): + self.conj_count[node] -= 1 + self.unfulfilled -= self.height[node] + + if self.conj_count[node] == 0 and not self.disj_set[node]: + self.delete_from_trie(node) + elif self.conj_count[node] < 0: + if node in self.disj_group.root: + node = self.disj_group[node] + self.disj_group.count[node] += 1 + else: + self.disj_group.root[node] = node + self.disj_group.roots.add(node) + self.disj_group.count[node] = 1 + self.disj_group.set[node] = set(self.disj_set[node]) + + while True: + for root in self.disj_group.roots: + if ( + root != node + and len(self.disj_group.set[root] - self.disj_group.set[node]) + <= self.disj_group.count[root] + and len(self.disj_group.set[node] - self.disj_group.set[root]) + <= self.disj_group.count[node] + ): + self.disj_group |= (root, node) + node = self.disj_group[node] + break + else: + break + + if len(self.disj_group.set[node]) == self.disj_group.count[node]: + self.unfulfilled -= sum(self.disj_max_height[index] for index in self.disj_group.set[node]) + + for leaf in set.union(*map(self.disj_constraints.__getitem__, self.disj_group.set[node])): + self.disj_set[leaf] -= self.disj_group.set[node] + + if self.conj_count[leaf] <= 0 and not self.disj_set[leaf]: + self.unfulfilled -= self.conj_count[leaf] * self.height[leaf] + self.conj_count[leaf] = 0 + self.delete_from_trie(leaf) + + 
self.disj_group.roots.remove(node) + + for root in self.disj_group.roots: + self.disj_group.set[root] -= self.disj_group.set[node] + + def try_delete(self, node: int, min_height: int): + while node and min_height <= self.height[node]: + if 0 < self.conj_count[node] + len(self.disj_set[node]): + output = node + else: + while self.is_invalid(self.output[node]): + self.output[node] = self.output[self.output[node]] + output = self.output[node] + + node = self.parent[node][0] + + if self.height[output] < min_height: + min_height -= 1 + else: + self.try_delete_leaf(output) + min_height = self.height[output] + + def does_advance(self, node: int, token_id: int) -> int: + while node and token_id not in self.trie[node]: + while self.is_invalid(self.suffix[node]): + self.suffix[node] = self.suffix[self.suffix[node]] + node = self.suffix[node] + return self.trie[node].get(token_id, 0) - for current_token in current_seq: - start = start[current_token] + def update(self, token_id: int): + node = self.current_node - next_tokens = list(start.keys()) + root = node + node = self.does_advance(node, token_id) + self.try_delete(root, self.height[node]) - return next_tokens + root = node + while node and not self.trie[node]: + node = self.suffix[node] + if root != node: + self.try_delete(root, self.height[node]) - def reached_leaf(self, current_seq): - next_tokens = self.next_tokens(current_seq) + self.current_node = node - return len(next_tokens) == 0 - - def count_leaves(self, root): - next_nodes = list(root.values()) - if len(next_nodes) == 0: - return 1 - else: - return sum([self.count_leaves(nn) for nn in next_nodes]) - - def has_subsets(self, trie, nested_token_ids): - """ - Returns whether # of leaves == # of words. Otherwise some word is a subset of another. 
- """ - leaf_count = self.count_leaves(trie) - return len(nested_token_ids) != leaf_count + def remaining(self) -> int: + return self.unfulfilled - self.height[self.current_node] class DisjunctiveConstraint(Constraint): @@ -282,69 +442,111 @@ def __init__(self, nested_token_ids: List[List[int]]): f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}." ) - self.trie = DisjunctiveTrie(nested_token_ids) + self.ac_automaton = ACAutomaton([nested_token_ids]) self.token_ids = nested_token_ids - self.seqlen = self.trie.max_height + self.seqlen = self.remaining() self.current_seq = [] - self.completed = False - def advance(self): - token_list = self.trie.next_tokens(self.current_seq) + for node in range(len(self.ac_automaton.trie)): + if self.ac_automaton.trie[node] and (self.ac_automaton.output[node] or self.ac_automaton.disj_set[node]): + raise ValueError( + "Each list in `nested_token_ids` can't be a non-suffix complete subset of another list, but is" + f" {nested_token_ids}." 
+ ) - if len(token_list) == 0: - return None - else: - return token_list + def advance(self): + return list(self.ac_automaton.trie[self.ac_automaton.current_node]) or None def does_advance(self, token_id: int): if not isinstance(token_id, int): raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}") - next_tokens = self.trie.next_tokens(self.current_seq) - - return token_id in next_tokens + return self.ac_automaton.does_advance(self.ac_automaton.current_node, token_id) != 0 def update(self, token_id: int): if not isinstance(token_id, int): raise ValueError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}") - stepped = False - completed = False - reset = False + self.ac_automaton.update(token_id) + self.current_seq.append(token_id) - if self.does_advance(token_id): - self.current_seq.append(token_id) - stepped = True - else: - reset = True - self.reset() + stepped = self.remaining() < self.seqlen + completed = self.completed + reset = self.remaining() == self.seqlen - completed = self.trie.reached_leaf(self.current_seq) - self.completed = completed + if reset: + self.reset() return stepped, completed, reset def reset(self): - self.completed = False - self.current_seq = [] + if self.ac_automaton.unfulfilled == self.seqlen: + self.ac_automaton.current_node = 0 + else: + self.ac_automaton = ACAutomaton([self.token_ids]) + self.current_seq.clear() def remaining(self): - if self.completed: - # since this can be completed without reaching max height - return 0 - else: - return self.seqlen - len(self.current_seq) + return self.ac_automaton.remaining() + + @property + def completed(self): + return self.remaining() == 0 def copy(self, stateful=False): - new_constraint = DisjunctiveConstraint(self.token_ids) + return copy.deepcopy(self) if stateful else type(self)(self.token_ids) - if stateful: - new_constraint.seq_len = self.seqlen - new_constraint.current_seq = self.current_seq - 
new_constraint.completed = self.completed - return new_constraint +class ConjunctiveDisjunctiveConstraint(Constraint): + r""" + A special [`Constraint`] that is fulfilled by fulfilling a series of conjunctive and disjunctive constraints. It + handles multiple constraints simultaneously, even for edge cases. + - allow [1,2,3] fulfill [[1,2],[2,3]], where words overlap + - allow [1,2,3,2,3,1] fulfill [[1],[2,3],[1,2,3]], where longest fulfills first + - allow [1,2]/[1,3] fulfill [[[1],[2]],[[1],[3]]], where condition is ambiguous + - allow [1,2]/[2] fulfill [[[1, 2, 3], [1, 2], [2]]], where redundant suffix shrinks + + Args: + force_words_ids (`Iterable[Union[Iterable[int], Iterable[Iterable[int]]]]`): + List of constraints to be fulfilled. If given `Iterable[int]`, this is treated as a positive constraint. If + given `Iterable[Iterable[int]]`, this is treated as a disjunctive positive constraint where one can allow + different forms of each word. + """ + + def __init__(self, force_words_ids: Iterable[Union[Iterable[int], Iterable[Iterable[int]]]]): + super(Constraint, self).__init__() + + self.force_words_ids = force_words_ids + self.ac_automaton = ACAutomaton(self.force_words_ids, False) + self.seqlen = self.remaining() + + def advance(self) -> Optional[List[int]]: + return list(self.ac_automaton.trie[self.ac_automaton.current_node]) or None + + def does_advance(self, token_id: int) -> bool: + return self.ac_automaton.does_advance(self.ac_automaton.current_node, token_id) != 0 + + def update(self, token_id: int): + self.ac_automaton.update(token_id) + + stepped = self.remaining() < self.seqlen + completed = self.remaining() == 0 + reset = self.remaining() == self.seqlen + + return stepped, completed, reset + + def reset(self): + if self.ac_automaton.unfulfilled == self.seqlen: + self.ac_automaton.current_node = 0 + else: + self.ac_automaton = ACAutomaton(self.force_words_ids, False) + + def remaining(self) -> int: + return self.ac_automaton.remaining() + + def 
copy(self, stateful: bool = False): + return copy.deepcopy(self) if stateful else type(self)(self.force_words_ids) class ConstraintListState: diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 68d742f5c3d0..45727ddbe845 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -80,6 +80,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class ConjunctiveDisjunctiveConstraint(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class Constraint(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_generation_beam_constraints.py b/tests/generation/test_generation_beam_constraints.py index 311cdc1429f3..7730d059b196 100644 --- a/tests/generation/test_generation_beam_constraints.py +++ b/tests/generation/test_generation_beam_constraints.py @@ -23,12 +23,12 @@ if is_torch_available(): import torch - from transformers.generation_beam_constraints import DisjunctiveConstraint + from transformers.generation_beam_constraints import ConjunctiveDisjunctiveConstraint, DisjunctiveConstraint @require_torch class ConstraintTest(unittest.TestCase): - def test_input_types(self): + def test_dc_input_types(self): # For consistency across different places the DisjunctiveConstraint is called, # dc.token_ids is a list of integers. It is also initialized only by integers. @@ -42,8 +42,8 @@ def test_input_types(self): with self.assertRaises(ValueError): DisjunctiveConstraint([torch.LongTensor([1, 2, 4]), torch.LongTensor([1, 2, 3, 4, 5])]) - def test_check_illegal_input(self): - # We can't have constraints that are complete subsets of another. This leads to a preverse + def test_dc_check_illegal_input(self): + # We can't have constraints that are non-suffix complete subsets of another. 
This leads to a perverse # interpretation of "constraint fulfillment": does generating [1,2,3] fulfill the constraint? # It would mean that it generated [1,2] which fulfills it, but it's in the middle of potentially # fulfilling [1,2,3,4]. If we believe that [1,2,3] does fulfill the constraint, then the algorithm @@ -53,63 +53,398 @@ def test_check_illegal_input(self): with self.assertRaises(ValueError): DisjunctiveConstraint(cset) # fails here - def test_example_progression(self): + cset = [[2, 3], [1, 2, 3, 4]] + + with self.assertRaises(ValueError): + DisjunctiveConstraint(cset) # fails here + + cset = [[3, 4], [1, 2, 3, 4]] + + DisjunctiveConstraint(cset) # succeeds here + + def test_dc_example_progression_and_copy(self): cset = [[1, 2, 3], [1, 2, 4]] dc = DisjunctiveConstraint(cset) + self.assertTrue(dc.does_advance(1)) stepped, completed, reset = dc.update(1) desired = stepped is True and completed is False and reset is False self.assertTrue(desired) self.assertTrue(not dc.completed) self.assertTrue(dc.current_seq == [1]) + self.assertTrue(dc.advance() == [2]) + self.assertTrue(dc.does_advance(2)) stepped, completed, reset = dc.update(2) desired = stepped is True and completed is False and reset is False self.assertTrue(desired) self.assertTrue(not dc.completed) self.assertTrue(dc.current_seq == [1, 2]) + self.assertTrue(dc.advance() == [3, 4]) + self.assertTrue(dc.remaining() == 1) + dc_copy = dc.copy() + self.assertTrue(dc_copy.remaining() == 3) + dc_copy = dc.copy(True) + self.assertTrue(dc_copy.remaining() == 1) + + self.assertTrue(not dc_copy.does_advance(5)) + stepped, completed, reset = dc_copy.update(5) + desired = stepped is False and completed is False and reset is True + self.assertTrue(desired) + self.assertTrue(not dc_copy.completed) + self.assertTrue(dc_copy.current_seq == []) # Reset! 
+ self.assertTrue(dc_copy.advance() == [1]) + + self.assertTrue(dc.does_advance(3)) stepped, completed, reset = dc.update(3) desired = stepped is True and completed is True and reset is False self.assertTrue(desired) self.assertTrue(dc.completed) # Completed! self.assertTrue(dc.current_seq == [1, 2, 3]) + self.assertTrue(dc.advance() is None) - def test_example_progression_unequal_three_mid_and_reset(self): + def test_dc_example_progression_unequal_three_mid_and_reset(self): cset = [[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]] dc = DisjunctiveConstraint(cset) + self.assertTrue(dc.does_advance(1)) stepped, completed, reset = dc.update(1) self.assertTrue(not dc.completed) self.assertTrue(dc.current_seq == [1]) + self.assertTrue(dc.advance() == [2]) + self.assertTrue(dc.does_advance(2)) stepped, completed, reset = dc.update(2) self.assertTrue(not dc.completed) self.assertTrue(dc.current_seq == [1, 2]) + self.assertTrue(dc.advance() == [3, 4, 5]) + self.assertTrue(dc.does_advance(4)) stepped, completed, reset = dc.update(4) self.assertTrue(not dc.completed) self.assertTrue(dc.current_seq == [1, 2, 4]) + self.assertTrue(dc.advance() == [5]) + self.assertTrue(dc.does_advance(5)) stepped, completed, reset = dc.update(5) self.assertTrue(dc.completed) # Completed! self.assertTrue(dc.current_seq == [1, 2, 4, 5]) + self.assertTrue(dc.advance() is None) dc.reset() + self.assertTrue(dc.does_advance(1)) stepped, completed, reset = dc.update(1) self.assertTrue(not dc.completed) self.assertTrue(dc.remaining() == 3) self.assertTrue(dc.current_seq == [1]) + self.assertTrue(dc.advance() == [2]) + self.assertTrue(dc.does_advance(2)) stepped, completed, reset = dc.update(2) self.assertTrue(not dc.completed) self.assertTrue(dc.remaining() == 2) self.assertTrue(dc.current_seq == [1, 2]) + self.assertTrue(dc.advance() == [3, 4, 5]) + self.assertTrue(dc.does_advance(5)) stepped, completed, reset = dc.update(5) self.assertTrue(dc.completed) # Completed! 
self.assertTrue(dc.remaining() == 0) self.assertTrue(dc.current_seq == [1, 2, 5]) + self.assertTrue(dc.advance() is None) + + def test_dc_example_progression_mid_overlap_two(self): + cset = [[1, 2, 3], [2, 4]] + + dc = DisjunctiveConstraint(cset) + + self.assertTrue(dc.does_advance(1)) + stepped, completed, reset = dc.update(1) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1]) + self.assertTrue(dc.advance() == [2]) + + self.assertTrue(dc.does_advance(2)) + stepped, completed, reset = dc.update(2) + self.assertTrue(not dc.completed) + self.assertTrue(dc.current_seq == [1, 2]) + self.assertTrue(dc.advance() == [3]) + + self.assertTrue(dc.does_advance(4)) + stepped, completed, reset = dc.update(4) + self.assertTrue(dc.completed) # Completed! + self.assertTrue(dc.current_seq == [1, 2, 4]) + self.assertTrue(dc.advance() is None) + + def test_cdc_input(self): + cset = [[], [[]], [[], []], [1], [[1]], [[1], []], [[1], [1]], [[1], [1], []], [[1], [1, 2], [0, 1, 2]]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) # succeeds here + + self.assertTrue(cdc.remaining() == 7) + self.assertTrue(len(cdc.ac_automaton.trie) == 4) + self.assertTrue(cdc.ac_automaton.conj_count[1] == 5) + self.assertTrue(cdc.ac_automaton.disj_set[1] == {0}) + self.assertTrue(cdc.ac_automaton.conj_count[3] == 0) + self.assertTrue(cdc.ac_automaton.disj_set[3] == {0}) + + def test_cdc_example_progression_and_copy(self): + cset = [[[1, 2, 3], [1, 2, 4]]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + desired = stepped is True and completed is False and reset is False + self.assertTrue(desired) + self.assertTrue(cdc.advance() == [2]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + desired = stepped is True and completed is False and reset is False + self.assertTrue(desired) + self.assertTrue(sorted(cdc.advance()) == [3, 4]) + + 
self.assertTrue(cdc.remaining() == 1) + cdc_copy = cdc.copy() + self.assertTrue(cdc_copy.remaining() == 3) + cdc_copy = cdc.copy(True) + self.assertTrue(cdc_copy.remaining() == 1) + + self.assertTrue(not cdc_copy.does_advance(5)) + stepped, completed, reset = cdc_copy.update(5) + desired = stepped is False and completed is False and reset is True + self.assertTrue(desired) # Reset! + self.assertTrue(cdc_copy.advance() == [1]) + + self.assertTrue(cdc.does_advance(3)) + stepped, completed, reset = cdc.update(3) + desired = stepped is True and completed is True and reset is False + self.assertTrue(desired) # Completed! + self.assertTrue(cdc.advance() is None) + + def test_cdc_example_progression_unequal_three_mid_and_reset(self): + cset = [[[1, 2, 3], [1, 2, 4, 5], [1, 2, 5]]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc.advance() == [2]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(sorted(cdc.advance()) == [3, 4, 5]) + + self.assertTrue(cdc.does_advance(4)) + stepped, completed, reset = cdc.update(4) + self.assertTrue(not completed) + self.assertTrue(cdc.advance() == [5]) + + self.assertTrue(cdc.does_advance(5)) + stepped, completed, reset = cdc.update(5) + self.assertTrue(completed) # Completed! 
+ self.assertTrue(cdc.advance() is None) + + cdc.reset() + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 3) + self.assertTrue(cdc.advance() == [2]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 2) + self.assertTrue(sorted(cdc.advance()) == [3, 4, 5]) + + self.assertTrue(cdc.does_advance(5)) + stepped, completed, reset = cdc.update(5) + self.assertTrue(completed) # Completed! + self.assertTrue(cdc.remaining() == 0) + self.assertTrue(cdc.advance() is None) + + def test_cdc_example_progression_mid_overlap_two(self): + cset = [[[1, 2, 3], [2, 4]]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc.advance() == [2]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(cdc.advance() == [3]) + + self.assertTrue(cdc.does_advance(4)) + stepped, completed, reset = cdc.update(4) + self.assertTrue(completed) # Completed! 
+ self.assertTrue(cdc.advance() is None) + + def test_cdc_example_progression_loop_three(self): + cset = [[[1], [2]], [[2], [3]], [[3], [1]]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 2) + self.assertTrue(sorted(cdc.advance()) == [1, 2, 3]) + + cdc_copy = cdc.copy(True) + self.assertTrue(cdc_copy.does_advance(1)) + stepped, completed, reset = cdc_copy.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc_copy.remaining() == 1) + self.assertTrue(cdc_copy.advance() == [2, 3]) + + self.assertTrue(cdc_copy.does_advance(2)) + stepped, completed, reset = cdc_copy.update(2) + self.assertTrue(completed) # Completed! + self.assertTrue(cdc_copy.remaining() == 0) + self.assertTrue(cdc_copy.advance() is None) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 1) + self.assertTrue(sorted(cdc.advance()) == [1, 2, 3]) + + cdc_copy = cdc.copy(True) + self.assertTrue(cdc_copy.does_advance(1)) + stepped, completed, reset = cdc_copy.update(1) + self.assertTrue(completed) # Completed! + self.assertTrue(cdc_copy.remaining() == 0) + self.assertTrue(cdc_copy.advance() is None) + + self.assertTrue(cdc.does_advance(3)) + stepped, completed, reset = cdc.update(3) + self.assertTrue(completed) # Completed! 
+ self.assertTrue(cdc.remaining() == 0) + self.assertTrue(cdc.advance() is None) + + def test_cdc_example_progression_overlap_four(self): + cset = [[1, 2, 3, 4, 5], [1], [3, 4], [4, 1]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 9) + self.assertTrue(cdc.advance() == [2]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 8) + self.assertTrue(cdc.advance() == [3]) + + self.assertTrue(cdc.does_advance(3)) + stepped, completed, reset = cdc.update(3) + self.assertTrue(cdc.remaining() == 7) + self.assertTrue(cdc.advance() == [4]) + + self.assertTrue(cdc.does_advance(4)) + stepped, completed, reset = cdc.update(4) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 6) + self.assertTrue(cdc.advance() == [5]) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + desired = stepped is True and completed is False and reset is False + self.assertTrue(desired) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 4) + self.assertTrue(cdc.advance() == [2]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 3) + self.assertTrue(cdc.advance() == [3]) + + self.assertTrue(cdc.does_advance(3)) + stepped, completed, reset = cdc.update(3) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 2) + self.assertTrue(cdc.advance() == [4]) + + self.assertTrue(cdc.does_advance(4)) + stepped, completed, reset = cdc.update(4) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 1) + self.assertTrue(cdc.advance() == [5]) + + self.assertTrue(cdc.does_advance(5)) + stepped, completed, reset = cdc.update(5) + self.assertTrue(completed) # 
Completed! + self.assertTrue(cdc.remaining() == 0) + self.assertTrue(cdc.advance() is None) + + def test_cdc_example_progression_ambiguous_eight(self): + cset = [[1], [[1], [2]], [[1], [2], [3]], [[3], [4]], [[3], [5]], [[3], [6], [7]], [[7], [8]], [4]] + + cdc = ConjunctiveDisjunctiveConstraint(cset) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 7) + self.assertTrue(sorted(cdc.advance()) == [1, 2, 3, 4, 5, 6, 7, 8]) + + self.assertTrue(cdc.does_advance(1)) + stepped, completed, reset = cdc.update(1) + desired = stepped is True and completed is False and reset is False + self.assertTrue(desired) + self.assertTrue(cdc.remaining() == 6) + self.assertTrue(sorted(cdc.advance()) == [1, 2, 3, 4, 5, 6, 7, 8]) + + self.assertTrue(cdc.does_advance(2)) + stepped, completed, reset = cdc.update(2) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 5) + self.assertTrue(sorted(cdc.advance()) == [3, 4, 5, 6, 7, 8]) + + self.assertTrue(cdc.does_advance(3)) + stepped, completed, reset = cdc.update(3) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 4) + self.assertTrue(sorted(cdc.advance()) == [3, 4, 5, 6, 7, 8]) + + self.assertTrue(cdc.does_advance(3)) + stepped, completed, reset = cdc.update(3) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 3) + self.assertTrue(sorted(cdc.advance()) == [3, 4, 5, 6, 7, 8]) + + self.assertTrue(cdc.does_advance(6)) + stepped, completed, reset = cdc.update(6) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 2) + self.assertTrue(sorted(cdc.advance()) == [4, 7, 8]) + + self.assertTrue(cdc.does_advance(4)) + stepped, completed, reset = cdc.update(4) + self.assertTrue(not completed) + self.assertTrue(cdc.remaining() == 1) + self.assertTrue(sorted(cdc.advance()) == [7, 8]) + + self.assertTrue(cdc.does_advance(7)) + stepped, completed, reset = cdc.update(7) + 
self.assertTrue(completed) # Completed! + self.assertTrue(cdc.remaining() == 0) + self.assertTrue(cdc.advance() is None)