-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnd_bitmap_set.py
65 lines (54 loc) · 1.84 KB
/
nd_bitmap_set.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from functools import reduce
from operator import mul
from typing import Optional
from bitmap_set import BitmapSet, Bounds, Elem, Elems
class NdBitmapSet(BitmapSet):
    """Bitmap set whose elements are N-dimensional integer coordinates.

    ``bounds`` is a tuple of positive ints giving the extent of each axis.
    An element is a tuple of ints with one component per axis, where each
    component ``v`` satisfies ``0 <= v < n`` for its axis extent ``n``.
    Elements are mapped to and from a flat integer index in row-major
    order (the last axis varies fastest).

    Storage and the set operations themselves are presumably provided by
    ``BitmapSet`` (not visible here); this subclass supplies validation,
    sizing and the coordinate <-> index mapping.
    """

    def __init__(self, bounds: Bounds, elems: Optional[Elems] = None) -> None:
        """Create a set over the grid described by ``bounds``.

        Args:
            bounds: Per-axis extents.
            elems: Optional initial elements, forwarded to ``BitmapSet``.
        """
        # Stored before delegating so ``shape`` is already readable by any
        # method the base initializer calls (assumed; base class not shown).
        self._shape = bounds
        super().__init__(bounds, elems)

    @classmethod
    def _validate_bounds(cls, bounds: Bounds) -> None:
        """Check that ``bounds`` is a tuple of strictly positive ints.

        Raises:
            TypeError: If ``bounds`` is not a tuple or has a non-int entry.
            ValueError: If any entry is zero or negative.
        """
        if not isinstance(bounds, tuple):
            raise TypeError(
                f"bounds must be a tuple, got {type(bounds).__name__}"
            )
        # Types are checked before values so ``n > 0`` is only ever
        # evaluated on ints.
        if not all(isinstance(n, int) for n in bounds):
            raise TypeError("bounds must contain only ints")
        if not all(n > 0 for n in bounds):
            raise ValueError("every bound must be a positive int")

    def _get_size(self, bounds: Bounds) -> int:
        """Return the number of grid cells: the product of all extents."""
        # NOTE(review): an empty tuple passes _validate_bounds, but reduce()
        # without an initial value raises TypeError on an empty iterable.
        # Decide whether zero-dimensional bounds should be rejected or
        # treated as size 1 once BitmapSet's expectations are confirmed.
        return reduce(mul, bounds)

    @property
    def shape(self) -> tuple[int, ...]:
        """Per-axis extents, exactly as passed to the constructor."""
        return self._shape

    def _validate_elem(self, elem: Elem) -> None:
        """Check that ``elem`` is a valid coordinate for this set's shape.

        Raises:
            TypeError: If ``elem`` is not a tuple, has a different number of
                components than ``shape``, or has a non-int component.
            ValueError: If any component lies outside ``[0, extent)``.
        """
        if not isinstance(elem, tuple):
            raise TypeError(
                f"elem must be a tuple, got {type(elem).__name__}"
            )
        # A wrong arity is reported as TypeError (not ValueError), matching
        # the original behavior.
        if len(elem) != len(self.shape):
            raise TypeError(
                f"elem must have {len(self.shape)} components, "
                f"got {len(elem)}"
            )
        if not all(isinstance(v, int) for v in elem):
            raise TypeError("elem components must be ints")
        if not all(0 <= v < n for v, n in zip(elem, self.shape)):
            raise ValueError(
                f"elem {elem!r} is out of bounds for shape {self.shape!r}"
            )

    def _hash(self, elem: Elem) -> int:
        """Map a coordinate tuple to its flat row-major index."""
        index = 0
        stride = 1
        # Walk from the last axis backwards: the last axis has stride 1 and
        # each earlier axis has a stride equal to the product of all later
        # extents.
        for v, n in zip(reversed(elem), reversed(self.shape)):
            index += v * stride
            stride *= n
        return index

    def _unhash(self, i: int) -> Elem:
        """Map a flat row-major index back to its coordinate tuple.

        Inverse of ``_hash`` for indices produced from valid elements.
        """
        coords = []
        # Peel off coordinates starting with the fastest-varying (last) axis.
        for n in reversed(self.shape):
            i, v = divmod(i, n)
            coords.append(v)
        # Coordinates were collected last-axis-first; restore axis order.
        coords.reverse()
        return tuple(coords)

    def _validate_other(self, other) -> None:
        """Validate ``other`` via the base class, then require equal shapes.

        Raises:
            ValueError: If ``other.shape`` differs from ``self.shape``.
        """
        super()._validate_other(other)
        if other.shape != self.shape:
            raise ValueError(
                f"shape mismatch: {other.shape!r} != {self.shape!r}"
            )