Skip to content

Commit 3483c2d

Browse files
radoeringbostonrwalker
authored andcommitted
Improve constraint union (python-poetry#283)
* Refactoring in preparation to increase coverage by adding additional tests * Increase test coverage for constraint intersection * Improve constraint union
1 parent 27a7614 commit 3483c2d

File tree

2 files changed

+121
-41
lines changed

2 files changed

+121
-41
lines changed

src/poetry/core/packages/constraints/constraint.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any
44
from typing import Union
55

6+
from poetry.core.packages.constraints import AnyConstraint
67
from poetry.core.packages.constraints.base_constraint import BaseConstraint
78
from poetry.core.packages.constraints.empty_constraint import EmptyConstraint
89

@@ -106,7 +107,19 @@ def union(self, other: "BaseConstraint") -> "BaseConstraint":
106107
UnionConstraint,
107108
)
108109

109-
return UnionConstraint(self, other)
110+
if other == self:
111+
return self
112+
113+
if self.operator == "!=" and other.operator == "==" and self.allows(other):
114+
return self
115+
116+
if other.operator == "!=" and self.operator == "==" and other.allows(self):
117+
return other
118+
119+
if other.operator == "==" and self.operator == "==":
120+
return UnionConstraint(self, other)
121+
122+
return AnyConstraint()
110123

111124
return other.union(self)
112125

tests/packages/constraints/test_constraint.py

+107-40
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1+
from typing import TYPE_CHECKING
2+
3+
import pytest
4+
5+
from poetry.core.packages.constraints import AnyConstraint
16
from poetry.core.packages.constraints.constraint import Constraint
27
from poetry.core.packages.constraints.empty_constraint import EmptyConstraint
38
from poetry.core.packages.constraints.multi_constraint import MultiConstraint
49
from poetry.core.packages.constraints.union_constraint import UnionConstraint
510

611

12+
if TYPE_CHECKING:
13+
from poetry.core.packages.constraints import BaseConstraint
14+
15+
716
def test_allows():
817
c = Constraint("win32")
918

@@ -41,46 +50,104 @@ def test_allows_all():
4150
assert not c.allows_all(UnionConstraint(Constraint("win32"), Constraint("linux")))
4251

4352

44-
def test_intersect():
45-
c = Constraint("win32")
46-
47-
intersection = c.intersect(Constraint("linux"))
48-
assert intersection == EmptyConstraint()
49-
50-
intersection = c.intersect(
51-
UnionConstraint(Constraint("win32"), Constraint("linux"))
52-
)
53-
assert intersection == Constraint("win32")
54-
55-
intersection = c.intersect(
56-
UnionConstraint(Constraint("linux"), Constraint("linux2"))
57-
)
58-
assert intersection == EmptyConstraint()
59-
60-
intersection = c.intersect(Constraint("linux", "!="))
61-
assert intersection == c
62-
63-
c = Constraint("win32", "!=")
64-
65-
intersection = c.intersect(Constraint("linux", "!="))
66-
assert intersection == MultiConstraint(
67-
Constraint("win32", "!="), Constraint("linux", "!=")
68-
)
69-
70-
71-
def test_union():
72-
c = Constraint("win32")
73-
74-
union = c.union(Constraint("linux"))
75-
assert union == UnionConstraint(Constraint("win32"), Constraint("linux"))
76-
77-
union = c.union(UnionConstraint(Constraint("win32"), Constraint("linux")))
78-
assert union == UnionConstraint(Constraint("win32"), Constraint("linux"))
79-
80-
union = c.union(UnionConstraint(Constraint("linux"), Constraint("linux2")))
81-
assert union == UnionConstraint(
82-
Constraint("win32"), Constraint("linux"), Constraint("linux2")
83-
)
53+
@pytest.mark.parametrize(
54+
("constraint1", "constraint2", "expected"),
55+
[
56+
(
57+
Constraint("win32"),
58+
Constraint("win32"),
59+
Constraint("win32"),
60+
),
61+
(
62+
Constraint("win32"),
63+
Constraint("linux"),
64+
EmptyConstraint(),
65+
),
66+
(
67+
Constraint("win32"),
68+
UnionConstraint(Constraint("win32"), Constraint("linux")),
69+
Constraint("win32"),
70+
),
71+
(
72+
Constraint("win32"),
73+
UnionConstraint(Constraint("linux"), Constraint("linux2")),
74+
EmptyConstraint(),
75+
),
76+
(
77+
Constraint("win32"),
78+
Constraint("linux", "!="),
79+
Constraint("win32"),
80+
),
81+
(
82+
Constraint("win32", "!="),
83+
Constraint("linux"),
84+
Constraint("linux"),
85+
),
86+
(
87+
Constraint("win32", "!="),
88+
Constraint("linux", "!="),
89+
MultiConstraint(Constraint("win32", "!="), Constraint("linux", "!=")),
90+
),
91+
],
92+
)
93+
def test_intersect(
94+
constraint1: "BaseConstraint",
95+
constraint2: "BaseConstraint",
96+
expected: "BaseConstraint",
97+
):
98+
intersection = constraint1.intersect(constraint2)
99+
assert intersection == expected
100+
101+
102+
@pytest.mark.parametrize(
103+
("constraint1", "constraint2", "expected"),
104+
[
105+
(
106+
Constraint("win32"),
107+
Constraint("win32"),
108+
Constraint("win32"),
109+
),
110+
(
111+
Constraint("win32"),
112+
Constraint("linux"),
113+
UnionConstraint(Constraint("win32"), Constraint("linux")),
114+
),
115+
(
116+
Constraint("win32"),
117+
UnionConstraint(Constraint("win32"), Constraint("linux")),
118+
UnionConstraint(Constraint("win32"), Constraint("linux")),
119+
),
120+
(
121+
Constraint("win32"),
122+
UnionConstraint(Constraint("linux"), Constraint("linux2")),
123+
UnionConstraint(
124+
Constraint("win32"), Constraint("linux"), Constraint("linux2")
125+
),
126+
),
127+
(
128+
Constraint("win32"),
129+
Constraint("linux", "!="),
130+
Constraint("linux", "!="),
131+
),
132+
(
133+
Constraint("win32", "!="),
134+
Constraint("linux"),
135+
Constraint("win32", "!="),
136+
),
137+
(
138+
Constraint("win32", "!="),
139+
Constraint("linux", "!="),
140+
AnyConstraint(),
141+
),
142+
],
143+
)
144+
def test_union(
145+
constraint1: "BaseConstraint",
146+
constraint2: "BaseConstraint",
147+
expected: "BaseConstraint",
148+
):
149+
union = constraint1.union(constraint2)
150+
assert union == expected
84151

85152

86153
def test_difference():

0 commit comments

Comments
 (0)