Skip to content

Commit 936d5f6

Browse files
authored
feat: enhance FilteredFileAdapter to handle flexible filtering for policies and roles (#360)
* feat: optimize filtered file adapter policy loading * style: standardize whitespace and formatting in filtered_file_adapter.py * feat: add test * test: improve test_load_filtered_policy_with_comments in test_filter.py * test: update test description for mixed filter
1 parent 7b64b85 commit 936d5f6

File tree

2 files changed

+192
-12
lines changed

2 files changed

+192
-12
lines changed

casbin/persist/adapters/filtered_file_adapter.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,28 @@ def load_filtered_policy(self, model, filter):
5252

5353
try:
5454
filter_value = [filter.__dict__["P"]] + [filter.__dict__["G"]]
55+
is_empty_filter = all(not f for f in filter_value) or all(
56+
all(not x.strip() for x in f) if f else True for f in filter_value
57+
)
58+
if is_empty_filter:
59+
return self.load_policy(model)
5560
except:
5661
raise RuntimeError("invalid filter type")
5762

5863
self.load_filtered_policy_file(model, filter_value, persist.load_policy_line)
5964
self.filtered = True
6065

61-
def load_filtered_policy_file(self, model, filter, hanlder):
66+
def load_filtered_policy_file(self, model, filter, handler):
6267
with open(self._file_path, "rb") as file:
63-
while True:
64-
line = file.readline()
68+
for line in file:
6569
line = line.decode().strip()
66-
if line == "\n":
70+
if not line or line == "\n":
6771
continue
68-
if not line:
69-
break
72+
7073
if filter_line(line, filter):
7174
continue
7275

73-
hanlder(line, model)
76+
handler(line, model)
7477

7578
# is_filtered returns true if the loaded policy has been filtered.
7679
def is_filtered(self):
@@ -92,10 +95,13 @@ def filter_line(line, filter):
9295
return True
9396
filter_slice = []
9497

95-
if p[0].strip() == "p":
96-
filter_slice = filter[0]
97-
elif p[0].strip() == "g":
98+
if p[0].strip() == "g":
99+
if not filter[1] or all(not x.strip() for x in filter[1]):
100+
return False
98101
filter_slice = filter[1]
102+
elif p[0].strip() == "p":
103+
filter_slice = filter[0]
104+
99105
return filter_words(p, filter_slice)
100106

101107

@@ -104,7 +110,7 @@ def filter_words(line, filter):
104110
return True
105111
skip_line = False
106112
for i, v in enumerate(filter):
107-
if len(v) > 0 and (v.strip() != line[i + 1].strip()):
113+
if v and v.strip() and (v.strip() != line[i + 1].strip()):
108114
skip_line = True
109115
break
110116

tests/test_filter.py

+175-1
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import casbin
15+
import os
1616
from unittest import TestCase
17+
import casbin
1718
from tests.test_enforcer import get_examples
1819
from casbin.persist.adapters import FilteredFileAdapter
20+
from casbin.persist.adapters.filtered_file_adapter import filter_line, filter_words
1921

2022

2123
class Filter:
@@ -141,3 +143,175 @@ def test_filtered_adapter_invalid_filepath(self):
141143

142144
with self.assertRaises(RuntimeError):
143145
e.load_filtered_policy(None)
146+
147+
def test_empty_filter_array(self):
148+
"""Test filter for empty array."""
149+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
150+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
151+
filter = Filter()
152+
filter.P = []
153+
filter.G = []
154+
155+
e.load_filtered_policy(filter)
156+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
157+
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))
158+
159+
def test_empty_string_filter(self):
160+
"""Test the filter for all empty strings."""
161+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
162+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
163+
filter = Filter()
164+
filter.P = ["", "", ""]
165+
filter.G = ["", "", ""]
166+
167+
e.load_filtered_policy(filter)
168+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
169+
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))
170+
171+
def test_mixed_empty_filter(self):
172+
"""Test the filter for mixed empty and non-empty strings."""
173+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
174+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
175+
filter = Filter()
176+
filter.P = ["", "domain1", ""]
177+
filter.G = ["", "", "domain1"]
178+
179+
e.load_filtered_policy(filter)
180+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
181+
self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"]))
182+
183+
def test_nonexistent_domain_filter(self):
184+
"""Testing the filter for a non-existent domain."""
185+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
186+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
187+
filter = Filter()
188+
filter.P = ["", "domain3"]
189+
filter.G = ["", "", "domain3"]
190+
191+
e.load_filtered_policy(filter)
192+
self.assertFalse(e.has_policy(["admin", "domain3", "data3", "read"]))
193+
194+
def test_empty_filter_array(self):
195+
"""Test filter for empty array."""
196+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
197+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
198+
filter = Filter()
199+
filter.P = []
200+
filter.G = []
201+
202+
try:
203+
e.load_filtered_policy(filter)
204+
except:
205+
raise RuntimeError("unexpected error with empty filter arrays")
206+
207+
self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty filters")
208+
209+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
210+
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))
211+
212+
def test_empty_string_filter(self):
213+
"""Test the filter for all empty strings."""
214+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
215+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
216+
filter = Filter()
217+
filter.P = ["", "", ""]
218+
filter.G = ["", "", ""]
219+
220+
try:
221+
e.load_filtered_policy(filter)
222+
except:
223+
raise RuntimeError("unexpected error with empty string filters")
224+
225+
self.assertFalse(e.is_filtered(), "Adapter should not be marked as filtered with empty string filters")
226+
227+
try:
228+
e.save_policy()
229+
except:
230+
raise RuntimeError("unexpected error in SavePolicy with empty string filters")
231+
232+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
233+
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))
234+
235+
def test_mixed_empty_filter(self):
236+
"""Test the filter for mixed empty and non-empty strings."""
237+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
238+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
239+
filter = Filter()
240+
filter.P = ["", "domain1", ""]
241+
filter.G = ["", "", "domain1"]
242+
243+
try:
244+
e.load_filtered_policy(filter)
245+
except:
246+
raise RuntimeError("unexpected error with mixed empty filters")
247+
248+
self.assertTrue(e.is_filtered(), "Adapter should be marked as filtered")
249+
250+
with self.assertRaises(RuntimeError):
251+
e.save_policy()
252+
253+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
254+
self.assertFalse(e.has_policy(["admin", "domain2", "data2", "read"]))
255+
256+
def test_whitespace_filter(self):
257+
"""Test the filter for all blank characters."""
258+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
259+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
260+
filter = Filter()
261+
filter.P = [" ", " ", "\t"]
262+
filter.G = ["\n", " ", " "]
263+
264+
e.load_filtered_policy(filter)
265+
266+
self.assertFalse(e.is_filtered())
267+
self.assertTrue(e.has_policy(["admin", "domain1", "data1", "read"]))
268+
self.assertTrue(e.has_policy(["admin", "domain2", "data2", "read"]))
269+
270+
def test_filter_line_edge_cases(self):
271+
"""Test the boundary cases of the filter_line function."""
272+
adapter = FilteredFileAdapter(get_examples("rbac_with_domains_policy.csv"))
273+
274+
self.assertFalse(filter_line("", [[""], [""]]))
275+
276+
self.assertFalse(filter_line("invalid_line", [[""], [""]]))
277+
278+
self.assertFalse(filter_line("p, admin, domain1, data1, read", None))
279+
280+
def test_filter_words_edge_cases(self):
281+
"""Test the boundary cases of the filter_words function."""
282+
self.assertTrue(filter_words(["p"], ["filter1", "filter2"]))
283+
284+
self.assertFalse(filter_words(["p", "admin", "domain1"], []))
285+
286+
line = ["admin", "domain1", "data*", "read"]
287+
filter = ["", "", "data1", ""]
288+
self.assertTrue(filter_words(line, filter))
289+
290+
def test_load_filtered_policy_with_comments(self):
291+
"""Test loading filtering policies with comments."""
292+
import tempfile
293+
import shutil
294+
295+
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
296+
with open(get_examples("rbac_with_domains_policy.csv"), "r") as source:
297+
shutil.copyfileobj(source, temp_file)
298+
299+
temp_file.write("\n# This is a comment\np, admin, domain1, data3, read")
300+
temp_file.flush()
301+
302+
temp_path = temp_file.name
303+
304+
try:
305+
adapter = FilteredFileAdapter(temp_path)
306+
e = casbin.Enforcer(get_examples("rbac_with_domains_model.conf"), adapter)
307+
filter = Filter()
308+
filter.P = ["", "domain1"]
309+
filter.G = ["", "", "domain1"]
310+
311+
e.load_filtered_policy(filter)
312+
self.assertTrue(e.has_policy(["admin", "domain1", "data3", "read"]))
313+
finally:
314+
try:
315+
os.unlink(temp_path)
316+
except OSError:
317+
pass

0 commit comments

Comments
 (0)