Skip to content

Commit 52961f0

Browse files
desertaxlezzstoatzz
andcommitted
Another way to handle multiple Inputs cache policies (#16786)
Co-authored-by: nate nowack <[email protected]>
1 parent 28f7cdf commit 52961f0

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

src/prefect/cache_policies.py

+27-29
Original file line numberDiff line numberDiff line change
@@ -157,39 +157,15 @@ class CompoundCachePolicy(CachePolicy):
157157

158158
policies: list[CachePolicy] = field(default_factory=list)
159159

160-
def __sub__(self, other: str) -> CachePolicy:
161-
if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance]
162-
raise TypeError("Can only subtract strings from key policies.")
163-
164-
# Subtract from each sub-policy; this may stack new Inputs(...) layers
165-
new_subpolicies: list[CachePolicy] = [p - other for p in self.policies]
166-
167-
# If there are multiple Inputs, unify them all into one
168-
inputs_policies = [p for p in new_subpolicies if isinstance(p, Inputs)]
160+
def __post_init__(self) -> None:
161+
# deduplicate any Inputs policies
162+
inputs_policies = [p for p in self.policies if isinstance(p, Inputs)]
163+
self.policies = [p for p in self.policies if not isinstance(p, Inputs)]
169164
if inputs_policies:
170-
# Gather all excludes into a single set
171165
all_excludes: set[str] = set()
172166
for inputs_policy in inputs_policies:
173167
all_excludes.update(inputs_policy.exclude)
174-
175-
# Remove all old Inputs from the subpolicy list
176-
new_subpolicies = [p for p in new_subpolicies if not isinstance(p, Inputs)]
177-
178-
# Append one merged Inputs that excludes the union
179-
merged_inputs = Inputs(exclude=sorted(all_excludes))
180-
new_subpolicies.append(merged_inputs)
181-
182-
# If no sub‐policies exist after this, create an Inputs with [other]
183-
# (to preserve the "always add an Inputs" behavior)
184-
if not new_subpolicies:
185-
new_subpolicies = [Inputs(exclude=[other])]
186-
187-
return CompoundCachePolicy(
188-
policies=new_subpolicies,
189-
key_storage=self.key_storage,
190-
isolation_level=self.isolation_level,
191-
lock_manager=self.lock_manager,
192-
)
168+
self.policies.append(Inputs(exclude=sorted(all_excludes)))
193169

194170
def compute_key(
195171
self,
@@ -212,6 +188,28 @@ def compute_key(
212188
return None
213189
return hash_objects(*keys, raise_on_failure=True)
214190

191+
def __add__(self, other: "CachePolicy") -> "CachePolicy":
192+
# Call the superclass add method to handle validation
193+
super().__add__(other)
194+
195+
if isinstance(other, CompoundCachePolicy):
196+
policies = [*self.policies, *other.policies]
197+
else:
198+
policies = [*self.policies, other]
199+
200+
return CompoundCachePolicy(
201+
policies=policies,
202+
key_storage=self.key_storage or other.key_storage,
203+
isolation_level=self.isolation_level or other.isolation_level,
204+
lock_manager=self.lock_manager or other.lock_manager,
205+
)
206+
207+
def __sub__(self, other: str) -> "CachePolicy":
208+
if not isinstance(other, str): # type: ignore[reportUnnecessaryIsInstance]
209+
raise TypeError("Can only subtract strings from key policies.")
210+
new = Inputs(exclude=[other])
211+
return CompoundCachePolicy(policies=[*self.policies, new])
212+
215213

216214
@dataclass
217215
class _None(CachePolicy):

tests/test_cache_policies.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ def test_nones_are_ignored(self):
158158
)
159159
assert compound_key is None
160160

161+
def test_adding_two_compound_policies_merges_policies(self):
162+
one = CompoundCachePolicy(policies=[Inputs(), TaskSource()])
163+
two = CompoundCachePolicy(policies=[RunId()])
164+
policy = one + two
165+
assert isinstance(policy, CompoundCachePolicy)
166+
assert len(policy.policies) == 3
167+
assert Inputs() in policy.policies
168+
assert RunId() in policy.policies
169+
assert TaskSource() in policy.policies
170+
161171
def test_compound_policy_deduplicates_inputs_on_subtraction(self):
162172
"""Regression test for https://github.com/PrefectHQ/prefect/issues/16773"""
163173
# Create a compound policy with multiple Inputs policies
@@ -169,6 +179,8 @@ def test_compound_policy_deduplicates_inputs_on_subtraction(self):
169179
Inputs(exclude=["y"]),
170180
]
171181
)
182+
# Inputs get combined into a single policy
183+
assert len(policy.policies) == 2
172184

173185
# Subtract a new key
174186
new_policy = policy - "z"
@@ -184,8 +196,7 @@ def test_compound_policy_deduplicates_inputs_on_subtraction(self):
184196
# So we should have one merged Inputs policy and one CompoundCachePolicy containing TaskSource
185197
assert len(new_policy.policies) == 2
186198
assert any(
187-
isinstance(p, CompoundCachePolicy)
188-
and any(isinstance(sp, TaskSource) for sp in p.policies)
199+
isinstance(p, CompoundCachePolicy) or isinstance(p, TaskSource)
189200
for p in new_policy.policies
190201
)
191202

0 commit comments

Comments
 (0)